diff --git a/.coveragerc b/.coveragerc index 6449c646a..04a26640d 100644 --- a/.coveragerc +++ b/.coveragerc @@ -7,6 +7,6 @@ omit = [report] exclude_lines = - :nocov: + :nocov: partial_branches = :nobr: diff --git a/.gitignore b/.gitignore index 5a0095c91..6e98694aa 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ __pycache__/ # coverage /.coverage /htmlcov +/coverage.xml # tests /tests/spec_*/ diff --git a/amaranth/hdl/_ir.py b/amaranth/hdl/_ir.py index 075991a22..f9bec3da6 100644 --- a/amaranth/hdl/_ir.py +++ b/amaranth/hdl/_ir.py @@ -71,6 +71,7 @@ def __init__(self, *, src_loc=None): self.src_loc = src_loc self.origins = None self.domains_propagated_up = {} + self.domain_renames = {} def add_domains(self, *domains): for domain in flatten(domains): @@ -655,6 +656,19 @@ def _assign_names(self, fragment: Fragment, hierarchy: "tuple[str]"): subfragment_name = _add_name(frag_info.assigned_names, subfragment_name) self._assign_names(subfragment, hierarchy=(*hierarchy, subfragment_name)) + def lookup_domain(self, domain, context): + if domain == "comb": + raise KeyError("comb") + if context is not None: + try: + fragment = self.elaboratables[context] + except KeyError: + raise ValueError(f"Elaboratable {context!r} is not a part of the design") + else: + fragment = self.fragment + domain = fragment.domain_renames.get(domain, domain) + return fragment.domains[domain] + ############################################################################################### >:3 diff --git a/amaranth/hdl/_xfrm.py b/amaranth/hdl/_xfrm.py index c1d12a1d3..062cddc72 100644 --- a/amaranth/hdl/_xfrm.py +++ b/amaranth/hdl/_xfrm.py @@ -252,6 +252,9 @@ def map_statements(self, fragment, new_fragment): for domain, statements in fragment.statements.items(): new_fragment.add_statements(domain, statements) + def map_domain_renames(self, fragment, new_fragment): + new_fragment.domain_renames = dict(fragment.domain_renames) + def map_memory_ports(self, fragment, new_fragment): if hasattr(self, "on_value"): for port in new_fragment._read_ports: @@ -318,6 +321,7 @@ def on_fragment(self, fragment): self.map_subfragments(fragment, new_fragment) self.map_domains(fragment, new_fragment) self.map_statements(fragment, new_fragment) + self.map_domain_renames(fragment, new_fragment) return new_fragment def __call__(self, value, *, src_loc_at=0): @@ -513,6 +517,15 @@ def map_statements(self, fragment, new_fragment): map(self.on_statement, statements) ) + def map_domain_renames(self, fragment, new_fragment): + new_fragment.domain_renames = { + src: self.domain_map.get(dst, dst) + for src, dst in fragment.domain_renames.items() + } + for src, dst in self.domain_map.items(): + if src not in new_fragment.domain_renames: + new_fragment.domain_renames[src] = dst + def map_memory_ports(self, fragment, new_fragment): super().map_memory_ports(fragment, new_fragment) for port in new_fragment._read_ports: diff --git a/amaranth/sim/__init__.py b/amaranth/sim/__init__.py index c239c52ca..e65c8a6b1 100644 --- a/amaranth/sim/__init__.py +++ b/amaranth/sim/__init__.py @@ -1,4 +1,8 @@ from .core import * -__all__ = ["Settle", "Delay", "Tick", "Passive", "Active", "Simulator"] +__all__ = [ + "DomainReset", "BrokenTrigger", "Simulator", + # deprecated + "Settle", "Delay", "Tick", "Passive", "Active", +] diff --git a/amaranth/sim/_async.py b/amaranth/sim/_async.py new file mode 100644 index 000000000..9c22ab27b --- /dev/null +++ b/amaranth/sim/_async.py @@ -0,0 +1,293 @@ +import typing +import operator +from contextlib import contextmanager + +from ..hdl import * +from ..hdl._ast import Slice +from ._base import BaseProcess, BaseEngine + + +__all__ = [ + "DomainReset", "BrokenTrigger", + "SampleTrigger", "ChangedTrigger", "EdgeTrigger", "DelayTrigger", + "TriggerCombination", "TickTrigger", + "SimulatorContext", "ProcessContext", "TestbenchContext", "AsyncProcess", +] + + +class DomainReset(Exception): + """Exception raised when the domain of a a tick trigger that is repeatedly awaited has its + reset asserted.""" + + +class BrokenTrigger(Exception): + """Exception raised when a trigger that is repeatedly awaited in an `async for` loop has + a matching event occur while the body of the `async for` loop is executing.""" + + +class SampleTrigger: + def __init__(self, value): + self.value = Value.cast(value) + if isinstance(value, ValueCastable): + self.shape = value.shape() + else: + self.shape = self.value.shape() + + +class ChangedTrigger: + def __init__(self, signal): + cast_signal = Value.cast(signal) + if not isinstance(cast_signal, Signal): + raise TypeError(f"Change trigger can only be used with a signal, not {signal!r}") + self.shape = signal.shape() + self.signal = cast_signal + + @property + def value(self): + return self.signal + + +class EdgeTrigger: + def __init__(self, signal, polarity): + cast_signal = Value.cast(signal) + if isinstance(cast_signal, Signal) and len(cast_signal) == 1: + self.signal, self.bit = cast_signal, 0 + elif (isinstance(cast_signal, Slice) and + len(cast_signal) == 1 and + isinstance(cast_signal.value, Signal)): + self.signal, self.bit = cast_signal.value, cast_signal.start + else: + raise TypeError(f"Edge trigger can only be used with a single-bit signal or " + f"a single-bit slice of a signal, not {signal!r}") + if polarity not in (0, 1): + raise ValueError(f"Edge trigger polarity must be 0 or 1, not {polarity!r}") + self.polarity = polarity + + +class DelayTrigger: + def __init__(self, interval): + self.interval_fs = round(float(interval) * 1e15) + + +class TriggerCombination: + def __init__(self, engine: BaseEngine, process: BaseProcess, *, + triggers: 'tuple[DelayTrigger|ChangedTrigger|SampleTrigger|EdgeTrigger, ...]' = ()): + self._engine = engine + self._process = process # private but used by engines + self._triggers = triggers # private but used by engines + + def sample(self, *values) -> 'TriggerCombination': + return TriggerCombination(self._engine, self._process, triggers=self._triggers + + tuple(SampleTrigger(value) for value in values)) + + def changed(self, *signals) -> 'TriggerCombination': + return TriggerCombination(self._engine, self._process, triggers=self._triggers + + tuple(ChangedTrigger(signal) for signal in signals)) + + def edge(self, signal, polarity) -> 'TriggerCombination': + return TriggerCombination(self._engine, self._process, triggers=self._triggers + + (EdgeTrigger(signal, polarity),)) + + def posedge(self, signal) -> 'TriggerCombination': + return self.edge(signal, 1) + + def negedge(self, signal) -> 'TriggerCombination': + return self.edge(signal, 0) + + def delay(self, interval) -> 'TriggerCombination': + return TriggerCombination(self._engine, self._process, triggers=self._triggers + + (DelayTrigger(interval),)) + + def __await__(self): + trigger = self._engine.add_trigger_combination(self, oneshot=True) + return trigger.__await__() + + async def __aiter__(self): + trigger = self._engine.add_trigger_combination(self, oneshot=False) + while True: + yield await trigger + + +class TickTrigger: + def __init__(self, engine: BaseEngine, process: BaseProcess, *, + domain: ClockDomain, sampled: 'tuple[ValueLike]' = ()): + self._engine = engine + self._process = process + self._domain = domain + self._sampled = sampled + + def sample(self, *values: ValueLike) -> 'TickTrigger': + return TickTrigger(self._engine, self._process, + domain=self._domain, sampled=(*self._sampled, *values)) + + async def until(self, condition: ValueLike): + if not isinstance(condition, ValueLike): + raise TypeError(f"Condition must be a value-like object, not {condition!r}") + tick = self.sample(condition).__aiter__() + done = False + while not done: + clk, rst, *values, done = await tick.__anext__() + if rst: + raise DomainReset + return tuple(values) + + async def repeat(self, count: int): + count = operator.index(count) + if count <= 0: + raise ValueError(f"Repeat count must be a positive integer, not {count!r}") + tick = self.__aiter__() + for _ in range(count): + clk, rst, *values = await tick.__anext__() + if rst: + raise DomainReset + return tuple(values) + + def _collect_trigger(self): + clk_polarity = (1 if self._domain.clk_edge == "pos" else 0) + if self._domain.async_reset and self._domain.rst is not None: + return (TriggerCombination(self._engine, self._process) + .edge(self._domain.clk, clk_polarity) + .edge(self._domain.rst, 1) + .sample(self._domain.rst) + .sample(*self._sampled)) + else: + return (TriggerCombination(self._engine, self._process) + .edge(self._domain.clk, clk_polarity) + .sample(Const(0)) + .sample(Const(0) if self._domain.rst is None else self._domain.rst) + .sample(*self._sampled)) + + def __await__(self): + trigger = self._engine.add_trigger_combination(self._collect_trigger(), oneshot=True) + clk_edge, rst_edge, rst_sample, *values = yield from trigger.__await__() + return (clk_edge, bool(rst_edge or rst_sample), *values) + + async def __aiter__(self): + trigger = self._engine.add_trigger_combination(self._collect_trigger(), oneshot=False) + while True: + clk_edge, rst_edge, rst_sample, *values = await trigger + yield (clk_edge, bool(rst_edge or rst_sample), *values) + + +class SimulatorContext: + def __init__(self, design, engine: BaseEngine, process: BaseProcess): + self._design = design + self._engine = engine + self._process = process + + def delay(self, interval) -> TriggerCombination: + return TriggerCombination(self._engine, self._process).delay(interval) + + def changed(self, *signals) -> TriggerCombination: + return TriggerCombination(self._engine, self._process).changed(*signals) + + def edge(self, signal, polarity) -> TriggerCombination: + return TriggerCombination(self._engine, self._process).edge(signal, polarity) + + def posedge(self, signal) -> TriggerCombination: + return TriggerCombination(self._engine, self._process).posedge(signal) + + def negedge(self, signal) -> TriggerCombination: + return TriggerCombination(self._engine, self._process).negedge(signal) + + @typing.overload + def tick(self, domain: str, *, context: Elaboratable = None) -> TickTrigger: ... # :nocov: + + @typing.overload + def tick(self, domain: ClockDomain) -> TickTrigger: ... # :nocov: + + def tick(self, domain="sync", *, context=None): + if domain == "comb": + raise ValueError("Combinational domain does not have a clock") + if isinstance(domain, ClockDomain): + if context is not None: + raise ValueError("Context cannot be provided if a clock domain is specified " + "directly") + else: + domain = self._design.lookup_domain(domain, context) + return TickTrigger(self._engine, self._process, domain=domain) + + @contextmanager + def critical(self): + try: + old_critical, self._process.critical = self._process.critical, True + yield + finally: + self._process.critical = old_critical + + +class ProcessContext(SimulatorContext): + def get(self, expr: ValueLike) -> 'typing.Never': + raise TypeError("`.get()` cannot be used to sample values in simulator processes; use " + "`.sample()` on a trigger object instead") + + @typing.overload + def set(self, expr: Value, value: int) -> None: ... # :nocov: + + @typing.overload + def set(self, expr: ValueCastable, value: typing.Any) -> None: ... # :nocov: + + def set(self, expr, value): + if isinstance(expr, ValueCastable): + shape = expr.shape() + if isinstance(shape, ShapeCastable): + value = shape.const(value) + value = Const.cast(value).value + self._engine.set_value(expr, value) + + +class TestbenchContext(SimulatorContext): + @typing.overload + def get(self, expr: Value) -> int: ... # :nocov: + + @typing.overload + def get(self, expr: ValueCastable) -> typing.Any: ... # :nocov: + + def get(self, expr): + value = self._engine.get_value(expr) + if isinstance(expr, ValueCastable): + shape = expr.shape() + if isinstance(shape, ShapeCastable): + return shape.from_bits(value) + return value + + @typing.overload + def set(self, expr: Value, value: int) -> None: ... # :nocov: + + @typing.overload + def set(self, expr: ValueCastable, value: typing.Any) -> None: ... # :nocov: + + def set(self, expr, value): + if isinstance(expr, ValueCastable): + shape = expr.shape() + if isinstance(shape, ShapeCastable): + value = shape.const(value) + value = Const.cast(value).value + self._engine.set_value(expr, value) + self._engine.step_design() + + +class AsyncProcess(BaseProcess): + def __init__(self, design, engine, constructor, *, testbench, background): + self.constructor = constructor + if testbench: + self.context = TestbenchContext(design, engine, self) + else: + self.context = ProcessContext(design, engine, self) + self.background = background + + self.reset() + + def reset(self): + self.runnable = True + self.critical = not self.background + self.waits_on = None + self.coroutine = self.constructor(self.context) + + def run(self): + try: + self.waits_on = self.coroutine.send(None) + except StopIteration: + self.critical = False + self.waits_on = None + self.coroutine = None diff --git a/amaranth/sim/_base.py b/amaranth/sim/_base.py index cf0ca6162..c63b95d81 100644 --- a/amaranth/sim/_base.py +++ b/amaranth/sim/_base.py @@ -1,15 +1,14 @@ -__all__ = ["BaseProcess", "BaseSignalState", "BaseMemoryState", "BaseSimulation", "BaseEngine"] +__all__ = ["BaseProcess", "BaseSignalState", "BaseMemoryState", "BaseEngineState", "BaseEngine"] class BaseProcess: __slots__ = () - def __init__(self): - self.reset() + runnable = False + critical = False def reset(self): - self.runnable = False - self.passive = True + raise NotImplementedError # :nocov: def run(self): raise NotImplementedError # :nocov: @@ -24,7 +23,7 @@ class BaseSignalState: curr = NotImplemented next = NotImplemented - def set(self, value): + def update(self, value): raise NotImplementedError # :nocov: @@ -40,7 +39,7 @@ def write(self, addr, value, mask=None): raise NotImplementedError # :nocov: -class BaseSimulation: +class BaseEngineState: def reset(self): raise NotImplementedError # :nocov: @@ -52,37 +51,47 @@ def get_memory(self, memory): slots = NotImplemented - def add_trigger(self, process, signal, *, trigger=None): + def set_delay_waker(self, interval, waker): raise NotImplementedError # :nocov: - def remove_trigger(self, process, signal): + def add_signal_waker(self, signal, waker): raise NotImplementedError # :nocov: - def add_memory_trigger(self, process, memory): + def add_memory_waker(self, memory, waker): raise NotImplementedError # :nocov: - def remove_memory_trigger(self, process, memory): + +class BaseEngine: + @property + def state(self) -> BaseEngineState: raise NotImplementedError # :nocov: - def wait_interval(self, process, interval): + @property + def now(self): raise NotImplementedError # :nocov: + def reset(self): + raise NotImplementedError # :nocov: -class BaseEngine: def add_clock_process(self, clock, *, phase, period): raise NotImplementedError # :nocov: - def add_coroutine_process(self, process, *, default_cmd): + def add_async_process(self, simulator, process): raise NotImplementedError # :nocov: - def add_testbench_process(self, process): + def add_async_testbench(self, simulator, process, *, background): raise NotImplementedError # :nocov: - def reset(self): + def add_trigger_combination(self, combination, *, oneshot): raise NotImplementedError # :nocov: - @property - def now(self): + def get_value(self, expr): + raise NotImplementedError # :nocov: + + def set_value(self, expr, value): + raise NotImplementedError # :nocov: + + def step_design(self): raise NotImplementedError # :nocov: def advance(self): diff --git a/amaranth/sim/_pyclock.py b/amaranth/sim/_pyclock.py index 816ef903c..7b6faf467 100644 --- a/amaranth/sim/_pyclock.py +++ b/amaranth/sim/_pyclock.py @@ -17,18 +17,21 @@ def __init__(self, state, signal, *, phase, period): def reset(self): self.runnable = True - self.passive = True + self.critical = False self.initial = True def run(self): self.runnable = False + def waker(): + self.runnable = True + if self.initial: self.initial = False - self.state.wait_interval(self, self.phase) + self.state.set_delay_waker(self.phase, waker) else: clk_state = self.state.slots[self.slot] - clk_state.set(not clk_state.curr) - self.state.wait_interval(self, self.period // 2) + clk_state.update(not clk_state.curr) + self.state.set_delay_waker(self.period // 2, waker) diff --git a/amaranth/sim/_pycoro.py b/amaranth/sim/_pycoro.py index 8c929b51f..20d749b16 100644 --- a/amaranth/sim/_pycoro.py +++ b/amaranth/sim/_pycoro.py @@ -1,136 +1,130 @@ import inspect +from .._utils import deprecated from ..hdl import * -from ..hdl._ast import Statement, Assign, SignalSet, ValueCastable -from .core import Tick, Settle, Delay, Passive, Active -from ._base import BaseProcess, BaseMemoryState -from ._pyeval import eval_value, eval_assign +from ..hdl._ast import Assign, ValueCastable -__all__ = ["PyCoroProcess"] +__all__ = ["Command", "Settle", "Delay", "Tick", "Passive", "Active", "PyCoroProcess"] -class PyCoroProcess(BaseProcess): - def __init__(self, state, domains, constructor, *, default_cmd=None, testbench=False, - on_command=None): - self.state = state - self.domains = domains - self.constructor = constructor - self.default_cmd = default_cmd - self.testbench = testbench - self.on_command = on_command +class Command: + pass - self.reset() - def reset(self): - self.runnable = True - self.passive = False +class Settle(Command): + @deprecated("The `Settle` command is deprecated per RFC 27. Use `add_testbench` to write " + "testbenches; in them, an equivalent of `yield Settle()` is performed " + "automatically.") + def __init__(self): + pass - self.coroutine = self.constructor() - self.waits_on = SignalSet() + def __repr__(self): + return "(settle)" - def src_loc(self): - coroutine = self.coroutine - if coroutine is None: - return None - while coroutine.gi_yieldfrom is not None and inspect.isgenerator(coroutine.gi_yieldfrom): - coroutine = coroutine.gi_yieldfrom - if inspect.isgenerator(coroutine): - frame = coroutine.gi_frame - if inspect.iscoroutine(coroutine): - frame = coroutine.cr_frame - return f"{inspect.getfile(frame)}:{inspect.getlineno(frame)}" - def add_trigger(self, signal, trigger=None): - self.state.add_trigger(self, signal, trigger=trigger) - self.waits_on.add(signal) +class Delay(Command): + def __init__(self, interval=None): + self.interval = None if interval is None else float(interval) - def clear_triggers(self): - for signal in self.waits_on: - self.state.remove_trigger(self, signal) - self.waits_on.clear() + def __repr__(self): + if self.interval is None: + return "(delay ε)" + else: + return f"(delay {self.interval * 1e6:.3}us)" - def run(self): - if self.coroutine is None: - return - self.clear_triggers() +class Tick(Command): + def __init__(self, domain="sync"): + if not isinstance(domain, (str, ClockDomain)): + raise TypeError("Domain must be a string or a ClockDomain instance, not {!r}" + .format(domain)) + assert domain != "comb" + self.domain = domain + + def __repr__(self): + return f"(tick {self.domain})" + + +class Passive(Command): + def __repr__(self): + return "(passive)" + + +class Active(Command): + def __repr__(self): + return "(active)" + + +def coro_wrapper(process, *, testbench, default_cmd=None): + async def inner(context): + def src_loc(coroutine): + if coroutine is None: + return None + while coroutine.gi_yieldfrom is not None and inspect.isgenerator(coroutine.gi_yieldfrom): + coroutine = coroutine.gi_yieldfrom + if inspect.isgenerator(coroutine): + frame = coroutine.gi_frame + if inspect.iscoroutine(coroutine): + frame = coroutine.cr_frame + return f"{inspect.getfile(frame)}:{inspect.getlineno(frame)}" + + coroutine = process() response = None exception = None while True: try: if exception is None: - command = self.coroutine.send(response) + command = coroutine.send(response) else: - command = self.coroutine.throw(exception) + command = coroutine.throw(exception) except StopIteration: - self.passive = True - self.coroutine = None - return False # no assignment + return try: if command is None: - command = self.default_cmd + command = default_cmd response = None exception = None - if self.on_command is not None: - self.on_command(self, command) - if isinstance(command, ValueCastable): command = Value.cast(command) if isinstance(command, Value): - response = eval_value(self.state, command) + response = context._engine.get_value(command) elif isinstance(command, Assign): - eval_assign(self.state, command.lhs, eval_value(self.state, command.rhs)) - if self.testbench: - return True # assignment; run a delta cycle + context.set(command.lhs, context._engine.get_value(command.rhs)) elif type(command) is Tick: - domain = command.domain - if isinstance(domain, ClockDomain): - pass - elif domain in self.domains: - domain = self.domains[domain] - else: - raise NameError("Received command {!r} that refers to a nonexistent " - "domain {!r} from process {!r}" - .format(command, command.domain, self.src_loc())) - self.add_trigger(domain.clk, trigger=1 if domain.clk_edge == "pos" else 0) - if domain.rst is not None and domain.async_reset: - self.add_trigger(domain.rst, trigger=1) - return False # no assignments - - elif self.testbench and (command is None or isinstance(command, Settle)): + await context.tick(command.domain) + + elif testbench and (command is None or isinstance(command, Settle)): raise TypeError(f"Command {command!r} is not allowed in testbenches") elif type(command) is Settle: - self.state.wait_interval(self, None) - return False # no assignments + await context.delay(0) elif type(command) is Delay: - # Internal timeline is in 1 fs integeral units, intervals are public API and in floating point - interval = int(command.interval * 1e15) if command.interval is not None else None - self.state.wait_interval(self, interval) - return False # no assignments + await context.delay(command.interval or 0) elif type(command) is Passive: - self.passive = True + context._process.critical = False elif type(command) is Active: - self.passive = False + context._process.critical = True elif command is None: # only possible if self.default_cmd is None raise TypeError("Received default command from process {!r} that was added " "with add_process(); did you mean to use Tick() instead?" - .format(self.src_loc())) + .format(src_loc(coroutine))) else: raise TypeError("Received unsupported command {!r} from process {!r}" - .format(command, self.src_loc())) + .format(command, src_loc(coroutine))) except Exception as exn: response = None exception = exn + + return inner diff --git a/amaranth/sim/_pyeval.py b/amaranth/sim/_pyeval.py index be72cc3fb..d7f0cdf84 100644 --- a/amaranth/sim/_pyeval.py +++ b/amaranth/sim/_pyeval.py @@ -3,6 +3,9 @@ from amaranth.hdl._ir import DriverConflict +__all__ = ["eval_value", "eval_format", "eval_assign"] + + def _eval_matches(test, patterns): if patterns is None: return True @@ -175,7 +178,7 @@ def _eval_assign_inner(sim, lhs, lhs_start, rhs, rhs_len): value &= (1 << len(lhs)) - 1 if lhs._signed and (value & (1 << (len(lhs) - 1))): value |= -1 << (len(lhs) - 1) - sim.slots[slot].set(value) + sim.slots[slot].update(value) elif isinstance(lhs, MemoryData._Row): lhs_stop = lhs_start + rhs_len if lhs_stop > len(lhs): @@ -223,5 +226,6 @@ def _eval_assign_inner(sim, lhs, lhs_start, rhs, rhs_len): else: raise ValueError(f"Value {lhs!r} cannot be assigned") + def eval_assign(sim, lhs, value): - _eval_assign_inner(sim, lhs, 0, value, len(lhs)) \ No newline at end of file + _eval_assign_inner(sim, lhs, 0, value, len(lhs)) diff --git a/amaranth/sim/_pyrtl.py b/amaranth/sim/_pyrtl.py index bbf75eaca..be624fa6b 100644 --- a/amaranth/sim/_pyrtl.py +++ b/amaranth/sim/_pyrtl.py @@ -18,7 +18,7 @@ class PyRTLProcess(BaseProcess): - __slots__ = ("is_comb", "runnable", "passive", "run") + __slots__ = ("is_comb", "runnable", "critical", "run") def __init__(self, *, is_comb): self.is_comb = is_comb @@ -27,7 +27,7 @@ def __init__(self, *, is_comb): def reset(self): self.runnable = self.is_comb - self.passive = True + self.critical = False class _PythonEmitter: @@ -443,10 +443,32 @@ def compile(cls, state, stmt): compiler = cls(state, emitter) compiler(stmt) for signal_index in output_indexes: - emitter.append(f"slots[{signal_index}].set(next_{signal_index})") + emitter.append(f"slots[{signal_index}].update(next_{signal_index})") return emitter.flush() +def comb_waker(process): + def waker(curr, next): + process.runnable = True + return True + return waker + + +def edge_waker(process, polarity): + def waker(curr, next): + if next == polarity: + process.runnable = True + return True + return waker + + +def memory_waker(process): + def waker(): + process.runnable = True + return True + return waker + + class _FragmentCompiler: def __init__(self, state): self.state = state @@ -486,7 +508,7 @@ def __call__(self, fragment): _StatementCompiler(self.state, emitter, inputs=inputs)(domain_stmts) if isinstance(fragment, MemoryInstance): - self.state.add_memory_trigger(domain_process, fragment._data) + self.state.add_memory_waker(fragment._data, memory_waker(domain_process)) memory_index = self.state.get_memory(fragment._data) rhs = _RHSValueCompiler(self.state, emitter, mode="curr", inputs=inputs) lhs = _LHSValueCompiler(self.state, emitter, rhs=rhs) @@ -500,16 +522,16 @@ def __call__(self, fragment): data = emitter.def_var("read_data", f"slots[{memory_index}].read({addr})") lhs(port._data)(data) + waker = comb_waker(domain_process) for input in inputs: - self.state.add_trigger(domain_process, input) + self.state.add_signal_waker(input, waker) else: domain = fragment.domains[domain_name] - clk_trigger = 1 if domain.clk_edge == "pos" else 0 - self.state.add_trigger(domain_process, domain.clk, trigger=clk_trigger) - if domain.rst is not None and domain.async_reset: - rst_trigger = 1 - self.state.add_trigger(domain_process, domain.rst, trigger=rst_trigger) + clk_polarity = 1 if domain.clk_edge == "pos" else 0 + self.state.add_signal_waker(domain.clk, edge_waker(domain_process, clk_polarity)) + if domain.async_reset and domain.rst is not None: + self.state.add_signal_waker(domain.rst, edge_waker(domain_process, 1)) for signal in domain_signals: signal_index = self.state.get_signal(signal) @@ -572,7 +594,7 @@ def __call__(self, fragment): for signal in domain_signals: signal_index = self.state.get_signal(signal) - emitter.append(f"slots[{signal_index}].set(next_{signal_index})") + emitter.append(f"slots[{signal_index}].update(next_{signal_index})") # There shouldn't be any exceptions raised by the generated code, but if there are # (almost certainly due to a bug in the code generator), use this environment variable diff --git a/amaranth/sim/core.py b/amaranth/sim/core.py index 936cd5fb3..668870d82 100644 --- a/amaranth/sim/core.py +++ b/amaranth/sim/core.py @@ -7,57 +7,16 @@ from ..hdl._ast import Value, ValueLike from ..hdl._mem import MemoryData from ._base import BaseEngine +from ._async import DomainReset, BrokenTrigger +from ._pycoro import Tick, Settle, Delay, Passive, Active, coro_wrapper -__all__ = ["Settle", "Delay", "Tick", "Passive", "Active", "Simulator"] - - -class Command: - pass - - -class Settle(Command): - @deprecated("The `Settle` command is deprecated per RFC 27. Use `add_testbench` to write " - "testbenches; in them, an equivalent of `yield Settle()` is performed " - "automatically.") - def __init__(self): - pass - - def __repr__(self): - return "(settle)" - - -class Delay(Command): - def __init__(self, interval=None): - self.interval = None if interval is None else float(interval) - - def __repr__(self): - if self.interval is None: - return "(delay ε)" - else: - return f"(delay {self.interval * 1e6:.3}us)" - - -class Tick(Command): - def __init__(self, domain="sync"): - if not isinstance(domain, (str, ClockDomain)): - raise TypeError("Domain must be a string or a ClockDomain instance, not {!r}" - .format(domain)) - assert domain != "comb" - self.domain = domain - - def __repr__(self): - return f"(tick {self.domain})" - - -class Passive(Command): - def __repr__(self): - return "(passive)" - - -class Active(Command): - def __repr__(self): - return "(active)" +__all__ = [ + "DomainReset", "BrokenTrigger", + "Simulator", + # deprecated + "Settle", "Delay", "Tick", "Passive", "Active", +] class Simulator: @@ -84,11 +43,23 @@ def _check_process(self, process): def add_process(self, process): process = self._check_process(process) - def wrapper(): - # Only start a bench process after comb settling, so that the initial values are correct. - yield object.__new__(Settle) - yield from process() - self._engine.add_coroutine_process(wrapper, default_cmd=None) + if inspect.iscoroutinefunction(process): + self._engine.add_async_process(self, process) + else: + def wrapper(): + # Only start a bench process after comb settling, so that the initial values are correct. + yield Active() + yield object.__new__(Settle) + yield from process() + wrap_process = coro_wrapper(wrapper, testbench=False) + self._engine.add_async_process(self, wrap_process) + + def add_testbench(self, process, *, background=False): + if inspect.iscoroutinefunction(process): + self._engine.add_async_testbench(self, process, background=background) + else: + process = coro_wrapper(process, testbench=True) + self._engine.add_async_testbench(self, process, background=background) @deprecated("The `add_sync_process` method is deprecated per RFC 27. Use `add_process` or `add_testbench` instead.") def add_sync_process(self, process, *, domain="sync"): @@ -99,6 +70,7 @@ def wrapper(): generator = process() result = None exception = None + yield Active() yield Tick(domain) while True: try: @@ -114,10 +86,8 @@ def wrapper(): except Exception as e: result = None exception = e - self._engine.add_coroutine_process(wrapper, default_cmd=Tick(domain)) - - def add_testbench(self, process): - self._engine.add_testbench_process(self._check_process(process)) + wrap_process = coro_wrapper(wrapper, testbench=False, default_cmd=Tick(domain)) + self._engine.add_async_process(self, wrap_process) def add_clock(self, period, *, phase=None, domain="sync", if_exists=False): """Add a clock process. diff --git a/amaranth/sim/pysim.py b/amaranth/sim/pysim.py index 64f86f451..da49aa0b0 100644 --- a/amaranth/sim/pysim.py +++ b/amaranth/sim/pysim.py @@ -7,9 +7,9 @@ from ..hdl import * from ..hdl._ast import SignalDict from ._base import * -from ._pyeval import eval_format, eval_value +from ._async import * +from ._pyeval import eval_format, eval_value, eval_assign from ._pyrtl import _FragmentCompiler -from ._pycoro import PyCoroProcess from ._pyclock import PyClockProcess @@ -268,16 +268,6 @@ def update_memory(self, timestamp, memory, addr): var_value = eval_format(self.state, repr) self.vcd_writer.change(vcd_var, timestamp, var_value) - def update_process(self, timestamp, process, command): - try: - vcd_var = self.vcd_process_vars[process] - except KeyError: - return - # Ensure that the waveform viewer displays a change point even if the previous command is - # the same as the next one. - self.vcd_writer.change(vcd_var, timestamp, "") - self.vcd_writer.change(vcd_var, timestamp, repr(command)) - def close(self, timestamp): if self.vcd_writer is not None: self.vcd_writer.close(timestamp) @@ -307,80 +297,80 @@ def close(self, timestamp): self.gtkw_file.close() -class _Timeline: +class _PyTimeline: def __init__(self): self.now = 0 - self.deadlines = dict() + self.wakers = {} def reset(self): self.now = 0 - self.deadlines.clear() + self.wakers.clear() - def at(self, run_at, process): - assert process not in self.deadlines - self.deadlines[process] = run_at - - def delay(self, delay_by, process): - if delay_by is None: - run_at = self.now - else: - run_at = self.now + delay_by - self.at(run_at, process) + def set_waker(self, interval, waker): + self.wakers[waker] = self.now + interval def advance(self): - nearest_processes = set() + nearest_wakers = set() nearest_deadline = None - for process, deadline in self.deadlines.items(): - if deadline is None: - if nearest_deadline is not None: - nearest_processes.clear() - nearest_processes.add(process) - nearest_deadline = self.now - break - elif nearest_deadline is None or deadline <= nearest_deadline: + for waker, deadline in self.wakers.items(): + if nearest_deadline is None or deadline <= nearest_deadline: assert deadline >= self.now if nearest_deadline is not None and deadline < nearest_deadline: - nearest_processes.clear() - nearest_processes.add(process) + nearest_wakers.clear() + nearest_wakers.add(waker) nearest_deadline = deadline - if not nearest_processes: + if not nearest_wakers: return False - for process in nearest_processes: - process.runnable = True - del self.deadlines[process] - self.now = nearest_deadline + for waker in nearest_wakers: + waker() + del self.wakers[waker] + self.now = nearest_deadline return True +def _run_wakers(wakers: list, *args): + # Python doesn't have `.retain()` :( + index = 0 + for waker in wakers: + if waker(*args): + wakers[index] = waker + index += 1 + del wakers[index:] + + class _PySignalState(BaseSignalState): - __slots__ = ("signal", "is_comb", "curr", "next", "waiters", "pending") + __slots__ = ("signal", "is_comb", "curr", "next", "wakers", "pending") def __init__(self, signal, pending): - self.signal = signal + self.signal = signal self.is_comb = False self.pending = pending - self.waiters = {} - self.curr = self.next = signal.init + self.wakers = list() + self.reset() - def set(self, value): - if self.next == value: - return - self.next = value - self.pending.add(self) + def reset(self): + self.curr = self.next = self.signal.init + + def add_waker(self, waker): + assert waker not in self.wakers + self.wakers.append(waker) + + def update(self, value): + if self.next != value: + self.next = value + self.pending.add(self) def commit(self): if self.curr == self.next: return False - self.curr = self.next - awoken_any = False - for process, trigger in self.waiters.items(): - if trigger is None or trigger == self.curr: - process.runnable = awoken_any = True - return awoken_any + _run_wakers(self.wakers, self.curr, self.next) + + self.curr = self.next + return True class _PyMemoryChange: @@ -388,73 +378,65 @@ class _PyMemoryChange: def __init__(self, state, addr): self.state = state - self.addr = addr + self.addr = addr class _PyMemoryState(BaseMemoryState): - __slots__ = ("memory", "data", "write_queue", "waiters", "pending") + __slots__ = ("memory", "data", "write_queue", "wakers", "pending") def __init__(self, memory, pending): - self.memory = memory + self.memory = memory self.pending = pending - self.waiters = {} + self.wakers = list() self.reset() def reset(self): self.data = list(self.memory._init._raw) - self.write_queue = [] + self.write_queue = {} - def commit(self): - if not self.write_queue: - return False - - for addr, value, mask in self.write_queue: - curr = self.data[addr] - value = (value & mask) | (curr & ~mask) - self.data[addr] = value - self.write_queue.clear() - - awoken_any = False - for process in self.waiters: - process.runnable = awoken_any = True - return awoken_any + def add_waker(self, waker): + assert waker not in self.wakers + self.wakers.append(waker) def read(self, addr): - if addr not in range(self.memory.depth): - return 0 - - return self.data[addr] + if addr in range(self.memory.depth): + return self.data[addr] + return 0 def write(self, addr, value, mask=None): - if addr not in range(self.memory.depth): - return - if mask == 0: - return + if addr in range(self.memory.depth): + if addr not in self.write_queue: + self.write_queue[addr] = self.data[addr] + if mask is not None: + value = (value & mask) | (self.write_queue[addr] & ~mask) + self.write_queue[addr] = value + self.pending.add(self) + + def commit(self): + assert self.write_queue # `commit()` is only called if `self` is pending - if mask is None: - mask = (1 << Shape.cast(self.memory.shape).width) - 1 + _run_wakers(self.wakers) - self.write_queue.append((addr, value, mask)) - self.pending.add(self) + changed = False + for addr, value in self.write_queue.items(): + if self.data[addr] != value: + self.data[addr] = value + changed = True + self.write_queue.clear() + return changed -class _PySimulation(BaseSimulation): +class _PyEngineState(BaseEngineState): def __init__(self): - self.timeline = _Timeline() - self.signals = SignalDict() - self.memories = {} - self.slots = [] - self.pending = set() + self.timeline = _PyTimeline() + self.signals = SignalDict() + self.memories = dict() + self.slots = list() + self.pending = set() def reset(self): self.timeline.reset() - for signal, index in self.signals.items(): - state = self.slots[index] - assert isinstance(state, _PySignalState) - state.curr = state.next = signal.init - for index in self.memories.values(): - state = self.slots[index] - assert isinstance(state, _PyMemoryState) + for state in self.slots: state.reset() self.pending.clear() @@ -476,35 +458,21 @@ def get_memory(self, memory): self.memories[memory] = index return index - def add_trigger(self, process, signal, *, trigger=None): - index = self.get_signal(signal) - assert (process not in self.slots[index].waiters or - self.slots[index].waiters[process] == trigger) - self.slots[index].waiters[process] = trigger + def set_delay_waker(self, interval, waker): + self.timeline.set_waker(interval, waker) - def remove_trigger(self, process, signal): - index = self.get_signal(signal) - assert process in self.slots[index].waiters - del self.slots[index].waiters[process] + def add_signal_waker(self, signal, waker): + self.slots[self.get_signal(signal)].add_waker(waker) - def add_memory_trigger(self, process, memory): - index = self.get_memory(memory) - self.slots[index].waiters[process] = None - - def remove_memory_trigger(self, process, memory): - index = self.get_memory(memory) - assert process in self.slots[index].waiters - del self.slots[index].waiters[process] - - def wait_interval(self, process, interval): - self.timeline.delay(interval, process) + def add_memory_waker(self, memory, waker): + self.slots[self.get_memory(memory)].add_waker(waker) def commit(self, changed=None): converged = True for state in self.pending: if changed is not None: if isinstance(state, _PyMemoryState): - for addr, _value, _mask in state.write_queue: + for addr in state.write_queue: changed.add(_PyMemoryChange(state, addr)) elif isinstance(state, _PySignalState): changed.add(state) @@ -516,57 +484,177 @@ def commit(self, changed=None): return converged +class _PyTriggerState: + def __init__(self, engine, combination, pending, *, oneshot): + self._engine = engine + self._combination = combination + self._active = pending + self._oneshot = oneshot + + self._result = None + self._broken = False + self._triggers_hit = set() + self._delay_wakers = dict() + + for trigger in combination._triggers: + if isinstance(trigger, SampleTrigger): + pass # does not cause a wakeup + elif isinstance(trigger, ChangedTrigger): + self.add_changed_waker(trigger) + elif isinstance(trigger, EdgeTrigger): + self.add_edge_waker(trigger) + elif isinstance(trigger, DelayTrigger): + self.add_delay_waker(trigger) + else: + assert False # :nocov: + + def add_changed_waker(self, trigger): + def waker(curr, next): + if self._broken: + return False + self.activate() + return not self._oneshot + self._engine.state.add_signal_waker(trigger.signal, waker) + + def add_edge_waker(self, trigger): + def waker(curr, next): + if self._broken: + return False + curr_bit = (curr >> trigger.bit) & 1 + next_bit = (next >> trigger.bit) & 1 + if curr_bit == next_bit or next_bit != trigger.polarity: + return True # wait until next edge + self._triggers_hit.add(trigger) + self.activate() + return not self._oneshot + self._engine.state.add_signal_waker(trigger.signal, waker) + + def add_delay_waker(self, trigger): + def waker(): + if self._broken: + return + self._triggers_hit.add(trigger) + self.activate() + self._engine.state.set_delay_waker(trigger.interval_fs, waker) + self._delay_wakers[waker] = trigger.interval_fs + + def activate(self): + if self._combination._process.waits_on is self: + self._active.add(self) + else: + self._broken = True + + def run(self): + result = [] + for trigger in self._combination._triggers: + if isinstance(trigger, (SampleTrigger, ChangedTrigger)): + value = self._engine.get_value(trigger.value) + if isinstance(trigger.shape, ShapeCastable): + result.append(trigger.shape.from_bits(value)) + else: + result.append(value) + elif isinstance(trigger, (EdgeTrigger, DelayTrigger)): + result.append(trigger in self._triggers_hit) + else: + assert False # :nocov: + self._result = tuple(result) + + self._combination._process.runnable = True + self._combination._process.waits_on = None + self._triggers_hit.clear() + for waker, interval_fs in self._delay_wakers.items(): + self._engine.state.set_delay_waker(interval_fs, waker) + + def __await__(self): + self._result = None + if self._broken: + raise BrokenTrigger + yield self + if self._broken: + raise BrokenTrigger + return self._result + + class PySimEngine(BaseEngine): def __init__(self, design): - self._state = _PySimulation() - self._timeline = self._state.timeline - self._design = design + + self._state = _PyEngineState() self._processes = _FragmentCompiler(self._state)(self._design.fragment) self._testbenches = [] self._delta_cycles = 0 self._vcd_writers = [] + self._active_triggers = set() - def add_clock_process(self, clock, *, phase, period): - self._processes.add(PyClockProcess(self._state, clock, - phase=phase, period=period)) + @property + def state(self) -> BaseEngineState: + return self._state - def add_coroutine_process(self, process, *, default_cmd): - self._processes.add(PyCoroProcess(self._state, self._design.fragment.domains, process, - default_cmd=default_cmd)) + @property + def now(self): + return self._state.timeline.now - def add_testbench_process(self, process): - self._testbenches.append(PyCoroProcess(self._state, self._design.fragment.domains, process, - testbench=True, on_command=self._debug_process)) + def _now_plus_deltas(self, fs_per_delta): + return self._state.timeline.now + self._delta_cycles * fs_per_delta def reset(self): self._state.reset() for process in self._processes: process.reset() - def _step_rtl(self): - # Performs the two phases of a delta cycle in a loop: + def add_clock_process(self, clock, *, phase, period): + self._processes.add(PyClockProcess(self._state, clock, + phase=phase, period=period)) + + def add_async_process(self, simulator, process): + self._processes.add(AsyncProcess(self._design, self, process, + testbench=False, background=True)) + + def add_async_testbench(self, simulator, process, *, background): + self._testbenches.append(AsyncProcess(self._design, self, process, + testbench=True, background=background)) + + def add_trigger_combination(self, combination, *, oneshot): + return _PyTriggerState(self, combination, self._active_triggers, oneshot=oneshot) + + def get_value(self, expr): + return eval_value(self._state, Value.cast(expr)) + + def set_value(self, expr, value): + assert isinstance(value, int) + return eval_assign(self._state, Value.cast(expr), value) + + def step_design(self): + # Performs the three phases of a delta cycle in a loop: converged = False while not converged: changed = set() if self._vcd_writers else None - # 1. eval: run and suspend every non-waiting process once, queueing signal changes + # 1a. trigger: run every active trigger, sampling values and waking up processes; + for trigger_state in self._active_triggers: + trigger_state.run() + self._active_triggers.clear() + + # 1b. eval: run every runnable processes once, queueing signal changes; for process in self._processes: if process.runnable: process.runnable = False process.run() + if type(process) is AsyncProcess and process.waits_on is not None: + assert type(process.waits_on) is _PyTriggerState, \ + "Async processes may only await simulation triggers" - # 2. commit: apply every queued signal change, waking up any waiting processes + # 2. commit: apply queued signal changes, activating any awaited triggers. converged = self._state.commit(changed) for vcd_writer in self._vcd_writers: - now_plus_deltas = self._now_plus_deltas(vcd_writer) + now_plus_deltas = self._now_plus_deltas(vcd_writer.fs_per_delta) for change in changed: - if isinstance(change, _PySignalState): + if type(change) is _PySignalState: signal_state = change vcd_writer.update_signal(now_plus_deltas, signal_state.signal) - elif isinstance(change, _PyMemoryChange): + elif type(change) is _PyMemoryChange: vcd_writer.update_memory(now_plus_deltas, change.state.memory, change.addr) else: @@ -574,41 +662,33 @@ def _step_rtl(self): self._delta_cycles += 1 - def _debug_process(self, process, command): - for vcd_writer in self._vcd_writers: - now_plus_deltas = self._now_plus_deltas(vcd_writer) - vcd_writer.update_process(now_plus_deltas, process, command) - - self._delta_cycles += 1 - - def _step_tb(self): - # Run processes waiting for an interval to expire (mainly `add_clock_process()``) - self._step_rtl() + def advance(self): + # Run triggers and processes until the simulation converges. + self.step_design() - # Run testbenches waiting for an interval to expire, or for a signal to change state + # Run testbenches that have been awoken in `step_design()` by active triggers. converged = False while not converged: converged = True - # Schedule testbenches in a deterministic, predictable order by iterating a list + # Schedule testbenches in a deterministic order (the one in which they were added). for testbench in self._testbenches: if testbench.runnable: testbench.runnable = False - while testbench.run(): - # Testbench has changed simulation state; run processes triggered by that - converged = False - self._step_rtl() - - def advance(self): - self._step_tb() - self._timeline.advance() - return any(not process.passive for process in (*self._processes, *self._testbenches)) - - @property - def now(self): - return self._timeline.now - - def _now_plus_deltas(self, vcd_writer): - return self._timeline.now + self._delta_cycles * vcd_writer.fs_per_delta + testbench.run() + if type(testbench) is AsyncProcess and testbench.waits_on is not None: + assert type(testbench.waits_on) is _PyTriggerState, \ + "Async testbenches may only await simulation triggers" + converged = False + + # Now that the simulation has converged for the current time, advance the timeline. + self._state.timeline.advance() + + # Check if the simulation has any critical processes or testbenches. + for runnables in (self._processes, self._testbenches): + for runnable in runnables: + if runnable.critical: + return True + return False @contextmanager def write_vcd(self, *, vcd_file, gtkw_file, traces, fs_per_delta): @@ -619,5 +699,5 @@ def write_vcd(self, *, vcd_file, gtkw_file, traces, fs_per_delta): self._vcd_writers.append(vcd_writer) yield finally: - vcd_writer.close(self._now_plus_deltas(vcd_writer)) + vcd_writer.close(self._now_plus_deltas(vcd_writer.fs_per_delta)) self._vcd_writers.remove(vcd_writer) diff --git a/pyproject.toml b/pyproject.toml index 599740ffe..2f861e47a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,7 +83,7 @@ examples = [ [tool.pdm.scripts] _.env_file = ".env.toolchain" -test.composite = ["test-code", "test-docs"] +test.composite = ["test-code", "test-docs", "coverage-xml"] test-code.env = {PYTHONWARNINGS = "error"} test-code.cmd = "python -m coverage run -m unittest discover -t . -s tests -v" test-docs.cmd = "sphinx-build -b doctest docs/ docs/_build" @@ -94,3 +94,4 @@ document-linkcheck.cmd = "sphinx-build docs/ docs/_linkcheck/ -b linkcheck" coverage-text.cmd = "python -m coverage report" coverage-html.cmd = "python -m coverage html" +coverage-xml.cmd = "python -m coverage xml" diff --git a/tests/test_hdl_ir.py b/tests/test_hdl_ir.py index a3326e5ca..34d5987f3 100644 --- a/tests/test_hdl_ir.py +++ b/tests/test_hdl_ir.py @@ -8,6 +8,7 @@ from amaranth.hdl._ir import * from amaranth.hdl._mem import * from amaranth.hdl._nir import SignalField, CombinationalCycle +from amaranth.hdl._xfrm import * from amaranth.lib import enum, data @@ -3561,3 +3562,39 @@ def test_cycle(self): r".*test_hdl_ir.py:\d+: signal a bit 0\n" r"$"): build_netlist(Fragment.get(m, None), []) + + +class DomainLookupTestCase(FHDLTestCase): + def test_domain_lookup(self): + m1 = Module() + m1_a = m1.domains.a = ClockDomain("a") + m1_b = m1.domains.b = ClockDomain("b") + m1_c = m1.domains.c = ClockDomain("c") + m2 = Module() + m3 = Module() + m3.d.sync += Print("m3") + m4 = Module() + m4.d.sync += Print("m4") + m4_d = m4.domains.d = ClockDomain("d") + m5 = Module() + m5.d.sync += Print("m5") + m5_d = m5.domains.d = ClockDomain("d") + + m1.submodules.m2 = xm2 = DomainRenamer({"a": "b"})(m2) + m2.submodules.m3 = xm3 = DomainRenamer("a")(m3) + m2.submodules.m4 = xm4 = DomainRenamer("b")(m4) + m2.submodules.m5 = xm5 = DomainRenamer("c")(m5) + + design = Fragment.get(m1, None).prepare() + + self.assertIs(design.lookup_domain("a", m1), m1_a) + self.assertIs(design.lookup_domain("b", m1), m1_b) + self.assertIs(design.lookup_domain("c", m1), m1_c) + self.assertIs(design.lookup_domain("a", xm2), m1_b) + self.assertIs(design.lookup_domain("b", xm2), m1_b) + self.assertIs(design.lookup_domain("c", xm2), m1_c) + self.assertIs(design.lookup_domain("sync", xm3), m1_b) + self.assertIs(design.lookup_domain("sync", xm4), m1_b) + self.assertIs(design.lookup_domain("sync", xm5), m1_c) + self.assertIs(design.lookup_domain("d", xm4), m4_d) + self.assertIs(design.lookup_domain("d", xm5), m5_d) diff --git a/tests/test_sim.py b/tests/test_sim.py index 27eeda44a..678e557ed 100644 --- a/tests/test_sim.py +++ b/tests/test_sim.py @@ -35,26 +35,25 @@ def assertStatement(self, stmt, inputs, output, init=0): frag.add_statements("comb", stmt) sim = Simulator(frag) - def process(): + async def process(ctx): for isig, input in zip(isigs, inputs): - yield isig.eq(input) - self.assertEqual((yield osig), output.value) + ctx.set(isig, ctx.get(input)) + self.assertEqual(ctx.get(osig), output.value) sim.add_testbench(process) with sim.write_vcd("test.vcd", "test.gtkw", traces=[*isigs, osig]): sim.run() frag = Fragment() sim = Simulator(frag) - def process(): + async def process(ctx): for isig, input in zip(isigs, inputs): - yield isig.eq(input) - yield Delay(0) + ctx.set(isig, ctx.get(input)) if isinstance(stmt, Assign): - yield stmt + ctx.set(stmt.lhs, ctx.get(stmt.rhs)) else: - yield from stmt - yield Delay(0) - self.assertEqual((yield osig), output.value) + for s in stmt: + ctx.set(s.lhs, ctx.get(s.rhs)) + self.assertEqual(ctx.get(osig), output.value) sim.add_testbench(process) with sim.write_vcd("test.vcd", "test.gtkw", traces=[*isigs, osig]): sim.run() @@ -597,18 +596,18 @@ def test_alu_bench(self): self.setUp_alu() with self.assertSimulation(self.m) as sim: sim.add_clock(1e-6) - def process(): - yield self.a.eq(5) - yield self.b.eq(1) - self.assertEqual((yield self.x), 4) - yield Tick() - self.assertEqual((yield self.o), 6) - yield self.s.eq(1) - yield Tick() - self.assertEqual((yield self.o), 4) - yield self.s.eq(2) - yield Tick() - self.assertEqual((yield self.o), 0) + async def process(ctx): + ctx.set(self.a, 5) + ctx.set(self.b, 1) + self.assertEqual(ctx.get(self.x), 4) + await ctx.tick() + self.assertEqual(ctx.get(self.o), 6) + ctx.set(self.s, 1) + await ctx.tick() + self.assertEqual(ctx.get(self.o), 4) + ctx.set(self.s, 2) + await ctx.tick() + self.assertEqual(ctx.get(self.o), 0) sim.add_testbench(process) def setUp_clock_phase(self): @@ -636,7 +635,7 @@ def test_clock_phase(self): sim.add_clock(period, phase=2*period/4, domain="phase180") sim.add_clock(period, phase=3*period/4, domain="phase270") - def proc(): + async def proc(ctx): clocks = [ self.phase0.clk, self.phase90.clk, @@ -644,9 +643,9 @@ def proc(): self.phase270.clk ] for i in range(16): - yield Tick("check") + await ctx.tick("check") for j, c in enumerate(clocks): - self.assertEqual((yield c), self.expected[j][i]) + self.assertEqual(ctx.get(c), self.expected[j][i]) sim.add_process(proc) @@ -663,16 +662,15 @@ def test_multiclock(self): sim.add_clock(1e-6, domain="sys") sim.add_clock(0.3e-6, domain="pix") - def sys_process(): - yield Passive() - yield Tick("sys") - yield Tick("sys") + async def sys_process(ctx): + await ctx.tick("sys") + await ctx.tick("sys") self.fail() - def pix_process(): - yield Tick("pix") - yield Tick("pix") - yield Tick("pix") - sim.add_testbench(sys_process) + async def pix_process(ctx): + await ctx.tick("pix") + await ctx.tick("pix") + await ctx.tick("pix") + sim.add_testbench(sys_process, background=True) sim.add_testbench(pix_process) def setUp_lhs_rhs(self): @@ -698,9 +696,9 @@ def test_run_until(self): m.d.sync += s.eq(0) with self.assertSimulation(m, deadline=100e-6) as sim: sim.add_clock(1e-6) - def process(): + async def process(ctx): for _ in range(101): - yield Delay(1e-6) + await ctx.delay(1e-6) self.fail() sim.add_testbench(process) @@ -710,12 +708,12 @@ def test_run_until_fail(self): m.d.sync += s.eq(0) with self.assertRaises(AssertionError): with self.assertSimulation(m, deadline=100e-6) as sim: - sim.add_clock(1e-6) - def process(): - for _ in range(99): - yield Delay(1e-6) - self.fail() - sim.add_testbench(process) + sim.add_clock(1e-6) + async def process(ctx): + for _ in range(99): + await ctx.delay(1e-6) + self.fail() + sim.add_testbench(process) def test_add_process_wrong(self): with self.assertSimulation(Module()) as sim: @@ -818,13 +816,13 @@ def setUp_memory(self, rd_synchronous=True, rd_transparent=False, wr_granularity def test_memory_init(self): self.setUp_memory() with self.assertSimulation(self.m) as sim: - def process(): - yield self.rdport.addr.eq(1) - yield Tick() - self.assertEqual((yield self.rdport.data), 0x55) - yield self.rdport.addr.eq(2) - yield Tick() - self.assertEqual((yield self.rdport.data), 0x00) + async def process(ctx): + ctx.set(self.rdport.addr, 1) + await ctx.tick() + self.assertEqual(ctx.get(self.rdport.data), 0x55) + ctx.set(self.rdport.addr, 2) + await ctx.tick() + self.assertEqual(ctx.get(self.rdport.data), 0x00) sim.add_clock(1e-6) sim.add_testbench(process) @@ -1443,3 +1441,497 @@ def testbench(): yield c.eq(0) sim.add_testbench(testbench) sim.run() + + def test_sample(self): + m = Module() + m.domains.sync = cd_sync = ClockDomain() + a = Signal(4) + b = Signal(4) + sim = Simulator(m) + + async def bench_a(ctx): + _, _, av, bv = await ctx.tick().sample(a, b) + ctx.set(a, 5) + self.assertEqual(av, 1) + self.assertEqual(bv, 2) + + async def bench_b(ctx): + _, _, av, bv = await ctx.tick().sample(a, b) + ctx.set(b, 6) + self.assertEqual(av, 1) + self.assertEqual(bv, 2) + + async def bench_c(ctx): + ctx.set(a, 1) + ctx.set(b, 2) + ctx.set(cd_sync.clk, 1) + ctx.set(a, 3) + ctx.set(b, 4) + + sim.add_testbench(bench_a) + sim.add_testbench(bench_b) + sim.add_testbench(bench_c) + sim.run() + + def test_latch(self): + q = Signal(4) + d = Signal(4) + g = Signal() + + async def latch(ctx): + async for dv, gv in ctx.changed(d, g): + if gv: + ctx.set(q, dv) + + async def testbench(ctx): + ctx.set(d, 1) + self.assertEqual(ctx.get(q), 0) + ctx.set(g, 1) + self.assertEqual(ctx.get(q), 1) + ctx.set(d, 2) + self.assertEqual(ctx.get(q), 2) + ctx.set(g, 0) + self.assertEqual(ctx.get(q), 2) + ctx.set(d, 3) + self.assertEqual(ctx.get(q), 2) + + sim = Simulator(Module()) + sim.add_process(latch) + sim.add_testbench(testbench) + sim.run() + + def test_edge(self): + a = Signal(4) + b = Signal(4) + + log = [] + + async def monitor(ctx): + async for res in ctx.posedge(a[0]).negedge(a[1]).sample(b): + log.append(res) + + async def testbench(ctx): + ctx.set(b, 8) + ctx.set(a, 0) + ctx.set(b, 9) + ctx.set(a, 1) + ctx.set(b, 10) + ctx.set(a, 2) + ctx.set(b, 11) + ctx.set(a, 3) + ctx.set(b, 12) + ctx.set(a, 4) + ctx.set(b, 13) + ctx.set(a, 6) + ctx.set(b, 14) + ctx.set(a, 5) + + sim = Simulator(Module()) + sim.add_process(monitor) + sim.add_testbench(testbench) + sim.run() + + self.assertEqual(log, [ + (True, False, 9), + (True, False, 11), + (False, True, 12), + (True, True, 14) + ]) + + def test_delay(self): + log = [] + + async def monitor(ctx): + async for res in ctx.delay(1).delay(2).delay(1): + log.append(res) + + async def testbench(ctx): + await ctx.delay(4) + + sim = Simulator(Module()) + sim.add_process(monitor) + sim.add_testbench(testbench) + sim.run() + + self.assertEqual(log, [ + (True, False, True), + (True, False, True), + (True, False, True), + (True, False, True), + ]) + + + def test_timeout(self): + a = Signal() + + log = [] + + async def monitor(ctx): + async for res in ctx.posedge(a).delay(1.5): + log.append(res) + + async def testbench(ctx): + await ctx.delay(0.5) + ctx.set(a, 1) + await ctx.delay(0.5) + ctx.set(a, 0) + await ctx.delay(0.5) + ctx.set(a, 1) + await ctx.delay(1) + ctx.set(a, 0) + await ctx.delay(1) + ctx.set(a, 1) + + sim = Simulator(Module()) + sim.add_process(monitor) + sim.add_testbench(testbench) + sim.run() + + self.assertEqual(log, [ + (True, False), + (True, False), + (False, True), + (True, False), + ]) + + def test_struct(self): + class MyStruct(data.Struct): + x: unsigned(4) + y: signed(4) + + a = Signal(MyStruct) + b = Signal(MyStruct) + + m = Module() + m.domains.sync = ClockDomain() + + log = [] + + async def adder(ctx): + async for av, in ctx.changed(a): + ctx.set(b, { + "x": av.y, + "y": av.x + }) + + async def monitor(ctx): + async for _, _, bv in ctx.tick().sample(b): + log.append(bv) + + async def testbench(ctx): + ctx.set(a.x, 1) + ctx.set(a.y, 2) + self.assertEqual(ctx.get(b.x), 2) + self.assertEqual(ctx.get(b.y), 1) + self.assertEqual(ctx.get(b), MyStruct.const({"x": 2, "y": 1})) + await ctx.tick() + ctx.set(a, MyStruct.const({"x": 3, "y": 4})) + await ctx.tick() + + sim = Simulator(m) + sim.add_process(adder) + sim.add_process(monitor) + sim.add_testbench(testbench) + sim.add_clock(1e-6) + sim.run() + + self.assertEqual(log, [ + MyStruct.const({"x": 2, "y": 1}), + MyStruct.const({"x": 4, "y": 3}), + ]) + + def test_valuecastable(self): + a = Signal(4) + b = Signal(4) + t = Signal() + idx = Signal() + arr = Array([a, b]) + + async def process(ctx): + async for _ in ctx.posedge(t): + ctx.set(arr[idx], 1) + + async def testbench(ctx): + self.assertEqual(ctx.get(arr[idx]), 0) + ctx.set(t, 1) + self.assertEqual(ctx.get(a), 1) + ctx.set(idx, 1) + ctx.set(arr[idx], 2) + self.assertEqual(ctx.get(b), 2) + + sim = Simulator(Module()) + sim.add_process(process) + sim.add_testbench(testbench) + sim.run() + + def test_tick_repeat_until(self): + ctr = Signal(4) + m = Module() + m.domains.sync = cd_sync = ClockDomain() + m.d.sync += ctr.eq(ctr + 1) + + async def testbench(ctx): + _, _, val, = await ctx.tick(cd_sync).sample(ctr) + self.assertEqual(val, 0) + self.assertEqual(ctx.get(ctr), 1) + val, = await ctx.tick(cd_sync).sample(ctr).until(ctr == 4) + self.assertEqual(val, 4) + self.assertEqual(ctx.get(ctr), 5) + val, = await ctx.tick(cd_sync).sample(ctr).repeat(3) + self.assertEqual(val, 7) + self.assertEqual(ctx.get(ctr), 8) + + sim = Simulator(m) + sim.add_testbench(testbench) + sim.add_clock(1e-6) + sim.run() + + def test_critical(self): + ctr = Signal(4) + m = Module() + m.domains.sync = cd_sync = ClockDomain() + m.d.sync += ctr.eq(ctr + 1) + + last_ctr = 0 + + async def testbench(ctx): + await ctx.tick().repeat(7) + + async def bgbench(ctx): + nonlocal last_ctr + while True: + await ctx.tick() + with ctx.critical(): + await ctx.tick().repeat(2) + last_ctr = ctx.get(ctr) + + sim = Simulator(m) + sim.add_testbench(testbench) + sim.add_testbench(bgbench, background=True) + sim.add_clock(1e-6) + sim.run() + + self.assertEqual(last_ctr, 9) + + def test_async_reset(self): + ctr = Signal(4) + m = Module() + m.domains.sync = cd_sync = ClockDomain(async_reset=True) + m.d.sync += ctr.eq(ctr + 1) + + log = [] + + async def monitor(ctx): + async for res in ctx.tick().sample(ctr): + log.append(res) + + async def testbench(ctx): + await ctx.posedge(cd_sync.clk) + await ctx.posedge(cd_sync.clk) + await ctx.negedge(cd_sync.clk) + ctx.set(cd_sync.rst, True) + await ctx.negedge(cd_sync.clk) + ctx.set(cd_sync.rst, False) + await ctx.posedge(cd_sync.clk) + await ctx.posedge(cd_sync.clk) + + async def repeat_bench(ctx): + with self.assertRaises(DomainReset): + await ctx.tick().repeat(4) + + async def until_bench(ctx): + with self.assertRaises(DomainReset): + await ctx.tick().until(ctr == 3) + + sim = Simulator(m) + sim.add_process(monitor) + sim.add_testbench(testbench) + sim.add_testbench(repeat_bench) + sim.add_testbench(until_bench) + sim.add_clock(1e-6) + sim.run() + + self.assertEqual(log, [ + (True, False, 0), + (True, False, 1), + (False, True, 2), + (True, True, 0), + (True, False, 0), + (True, False, 1), + ]) + + def test_sync_reset(self): + ctr = Signal(4) + m = Module() + m.domains.sync = cd_sync = ClockDomain() + m.d.sync += ctr.eq(ctr + 1) + + log = [] + + async def monitor(ctx): + async for res in ctx.tick().sample(ctr): + log.append(res) + + async def testbench(ctx): + await ctx.posedge(cd_sync.clk) + await ctx.posedge(cd_sync.clk) + await ctx.negedge(cd_sync.clk) + ctx.set(cd_sync.rst, True) + await ctx.negedge(cd_sync.clk) + ctx.set(cd_sync.rst, False) + await ctx.posedge(cd_sync.clk) + await ctx.posedge(cd_sync.clk) + + sim = Simulator(m) + sim.add_process(monitor) + sim.add_testbench(testbench) + sim.add_clock(1e-6) + sim.run() + + self.assertEqual(log, [ + (True, False, 0), + (True, False, 1), + (True, True, 2), + (True, False, 0), + (True, False, 1), + ]) + + def test_broken_multiedge(self): + a = Signal() + + broken_trigger_hit = False + + async def testbench(ctx): + await ctx.delay(1) + ctx.set(a, 1) + ctx.set(a, 0) + ctx.set(a, 1) + ctx.set(a, 0) + await ctx.delay(1) + + async def monitor(ctx): + nonlocal broken_trigger_hit + try: + async for _ in ctx.edge(a, 1): + pass + except BrokenTrigger: + broken_trigger_hit = True + + sim = Simulator(Module()) + sim.add_testbench(testbench) + sim.add_testbench(monitor, background=True) + sim.run() + + self.assertTrue(broken_trigger_hit) + + def test_broken_other_trigger(self): + m = Module() + m.domains.sync = ClockDomain() + + async def testbench(ctx): + with self.assertRaises(BrokenTrigger): + async for _ in ctx.tick(): + await ctx.delay(2) + + sim = Simulator(m) + sim.add_testbench(testbench) + sim.add_clock(1) + sim.run() + + def test_abandon_delay(self): + ctr = Signal(4) + m = Module() + m.domains.sync = ClockDomain() + m.d.sync += ctr.eq(ctr + 1) + + async def testbench(ctx): + async for _ in ctx.delay(1).delay(1): + break + + await ctx.tick() + await ctx.tick() + self.assertEqual(ctx.get(ctr), 2) + + sim = Simulator(m) + sim.add_testbench(testbench) + sim.add_clock(4) + sim.run() + + def test_abandon_changed(self): + ctr = Signal(4) + a = Signal() + m = Module() + m.domains.sync = ClockDomain() + m.d.sync += ctr.eq(ctr + 1) + + async def testbench(ctx): + async for _ in ctx.changed(a): + break + + await ctx.tick() + await ctx.tick() + self.assertEqual(ctx.get(ctr), 2) + + async def change(ctx): + await ctx.delay(1) + ctx.set(a, 1) + await ctx.delay(1) + ctx.set(a, 0) + await ctx.delay(1) + ctx.set(a, 1) + + sim = Simulator(m) + sim.add_testbench(testbench) + sim.add_testbench(change) + sim.add_clock(4) + sim.run() + + def test_trigger_wrong(self): + a = Signal(4) + m = Module() + m.domains.sync = cd_sync = ClockDomain() + + reached_tb = False + reached_proc = False + + async def process(ctx): + nonlocal reached_proc + with self.assertRaisesRegex(TypeError, + r"^`\.get\(\)` cannot be used to sample values in simulator processes; " + r"use `\.sample\(\)` on a trigger object instead$"): + ctx.get(a) + reached_proc = True + + async def testbench(ctx): + nonlocal reached_tb + with self.assertRaisesRegex(TypeError, + r"^Change trigger can only be used with a signal, not \(~ \(sig a\)\)$"): + await ctx.changed(~a) + with self.assertRaisesRegex(TypeError, + r"^Edge trigger can only be used with a single-bit signal or " + r"a single-bit slice of a signal, not \(sig a\)$"): + await ctx.posedge(a) + with self.assertRaisesRegex(ValueError, + r"^Edge trigger polarity must be 0 or 1, not 2$"): + await ctx.edge(a[0], 2) + with self.assertRaisesRegex(TypeError, + r"^Condition must be a value-like object, not 'meow'$"): + await ctx.tick().until("meow") + with self.assertRaisesRegex(ValueError, + r"^Repeat count must be a positive integer, not 0$"): + await ctx.tick().repeat(0) + with self.assertRaisesRegex(ValueError, + r"^Combinational domain does not have a clock$"): + await ctx.tick("comb") + with self.assertRaisesRegex(ValueError, + r"^Context cannot be provided if a clock domain is specified directly$"): + await ctx.tick(cd_sync, context=m) + reached_tb = True + + sim = Simulator(m) + sim.add_process(process) + sim.add_testbench(testbench) + sim.run() + + self.assertTrue(reached_tb) + self.assertTrue(reached_proc)