From 8dc27d8b9371c9a7e3a5311a50c9aaca1a526324 Mon Sep 17 00:00:00 2001 From: Zeyi Wang Date: Fri, 13 Jan 2023 19:42:30 +0000 Subject: [PATCH 01/11] feat: hierarchical flattening draft --- hdl21/flatten.py | 113 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 hdl21/flatten.py diff --git a/hdl21/flatten.py b/hdl21/flatten.py new file mode 100644 index 00000000..dce059f8 --- /dev/null +++ b/hdl21/flatten.py @@ -0,0 +1,113 @@ +import copy +import random +from dataclasses import dataclass, field, replace +from enum import Enum, auto +from typing import Any, Generator + +import tree +from loguru import logger + +import hdl21 as h +from hdl21.primitives import MosType + + +@dataclass +class PrimNode: + inst: h.Instance + path: tuple[h.Module, ...] = tuple() + conns: dict[str, h.Signal] = field(default_factory=dict) # type: ignore + + def __str__(self): + name = self.make_name() + path = [p.name for p in self.path] + conns = flat_conns(self.conns) + return str(dict(name=name, path=path, conns=conns)) + + def make_name(self): + return ":".join([p.name or "_" for p in self.path]) + + +def walk(m: h.Module, parents=tuple(), conns=dict()) -> Generator[PrimNode, None, None]: + if not conns: + conns = {**m.signals, **m.ports} + for inst in m.instances.values(): + logger.debug(f"walk: {m.name} / {inst.name}") + new_conns = {} + new_parents = parents + (inst,) + for src_port_name, sig in inst.conns.items(): + match sig: + case h.Signal(): + key = sig.name + case h.PortRef(): + key = sig.portname + case _: + raise ValueError(f"unexpected signal type: {type(sig)}") + new_sig_name = ":".join([p.name for p in parents] + [key]) + if key in conns: + target_sig = conns[key] + elif key in m.signals: + target_sig = replace(m.signals[key], name=new_sig_name) + elif key in m.ports: + target_sig = replace(m.ports[key], name=new_sig_name) + else: + raise ValueError(f"signal {key} not found") + new_conns[src_port_name] = target_sig + + if isinstance(inst.of, h.PrimitiveCall): + yield PrimNode(inst, new_parents, new_conns) + else: + yield from walk(inst.of, new_parents, new_conns) + + +def _find_signal(m: h.Module, name: str) -> h.Signal: + for port_name in m.ports: + if port_name == name: + return m.ports[port_name] + for sig_name in m.signals: + if sig_name == name: + return m.signals[sig_name] + raise ValueError(f"Signal {name} not found in module {m.name}") + + +def is_flat(m: h.Instance | h.Instantiable) -> bool: + if isinstance(m, h.Instance): + return is_flat(m.of) + elif isinstance(m, (h.PrimitiveCall, h.GeneratorCall, h.ExternalModuleCall)): + return True + elif isinstance(m, h.Module): + insts = m.instances.values() + return all(isinstance(inst.of, h.PrimitiveCall) for inst in insts) + else: + raise ValueError(f"Unexpected type {type(m)}") + + +def flat_conns(conns): + flat = tree.flatten_with_path(conns) + return {":".join(reversed(path)): value.name for path, value in flat} + + +def flat_module(m: h.Module): + m = h.elaborate(m) + + nodes = list(walk(m)) + for n in nodes: + logger.debug(str(n)) + + new_module = h.Module((m.name or "module") + "_flat") + for port_name in m.ports: + new_module.add(h.Port(name=port_name)) + + for n in nodes: + for sig in n.conns.values(): + sig_name = sig.name + if sig_name not in new_module.ports: + new_module.add(h.Signal(name=sig_name)) + + for n in nodes: + new_inst = new_module.add(n.inst.of(), name=n.make_name()) + + for src_port_name, sig in n.conns.items(): + matching_sig = _find_signal(new_module, sig.name) + new_inst.connect(src_port_name, matching_sig) + + return new_module From ad6212c64273ef1a7279b4139492e06dbf0b6967 Mon Sep 17 00:00:00 2001 From: Zeyi Wang Date: Fri, 13 Jan 2023 19:56:50 +0000 Subject: [PATCH 02/11] refactor: remove `tree` dependency and clean up logging --- hdl21/flatten.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/hdl21/flatten.py b/hdl21/flatten.py index dce059f8..0e9818f1 100644 --- a/hdl21/flatten.py +++ b/hdl21/flatten.py @@ -1,14 +1,7 @@ -import copy -import random from dataclasses import dataclass, field, replace -from enum import Enum, auto -from typing import Any, Generator - -import tree -from loguru import logger +from typing import Generator import hdl21 as h -from hdl21.primitives import MosType @dataclass @@ -31,7 +24,7 @@ def walk(m: h.Module, parents=tuple(), conns=dict()) -> Generator[PrimNode, None if not conns: conns = {**m.signals, **m.ports} for inst in m.instances.values(): - logger.debug(f"walk: {m.name} / {inst.name}") + # logger.debug(f"walk: {m.name} / {inst.name}") new_conns = {} new_parents = parents + (inst,) for src_port_name, sig in inst.conns.items(): @@ -81,28 +74,38 @@ def is_flat(m: h.Instance | h.Instantiable) -> bool: raise ValueError(f"Unexpected type {type(m)}") +def _walk_conns(conns, parents=tuple()): + for key, val in conns.items(): + if isinstance(val, dict): + yield from _walk_conns(val, parents + (key,)) + else: + yield parents + (key,), val + + def flat_conns(conns): - flat = tree.flatten_with_path(conns) - return {":".join(reversed(path)): value.name for path, value in flat} + return {":".join(reversed(path)): value.name for path, value in _walk_conns(conns)} def flat_module(m: h.Module): m = h.elaborate(m) + # recursively walk the module and collect all primitive instances nodes = list(walk(m)) - for n in nodes: - logger.debug(str(n)) + # for n in nodes: + # logger.debug(str(n)) new_module = h.Module((m.name or "module") + "_flat") for port_name in m.ports: new_module.add(h.Port(name=port_name)) + # add all signals to the root level for n in nodes: for sig in n.conns.values(): sig_name = sig.name if sig_name not in new_module.ports: new_module.add(h.Signal(name=sig_name)) + # add all connections to the root level with names resolved for n in nodes: new_inst = new_module.add(n.inst.of(), name=n.make_name()) From d7dbc24e896794d336d0c95f50106f75c3a5af31 Mon Sep 17 00:00:00 2001 From: Zeyi Wang Date: Mon, 23 Jan 2023 18:16:39 +0000 Subject: [PATCH 03/11] fix: type hint for older Python version --- hdl21/flatten.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hdl21/flatten.py b/hdl21/flatten.py index 0e9818f1..5bc1db71 100644 --- a/hdl21/flatten.py +++ b/hdl21/flatten.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field, replace -from typing import Generator +from typing import Generator, Union import hdl21 as h @@ -62,7 +62,7 @@ def _find_signal(m: h.Module, name: str) -> h.Signal: raise ValueError(f"Signal {name} not found in module {m.name}") -def is_flat(m: h.Instance | h.Instantiable) -> bool: +def is_flat(m: Union[h.Instance, h.Instantiable]) -> bool: if isinstance(m, h.Instance): return is_flat(m.of) elif isinstance(m, (h.PrimitiveCall, h.GeneratorCall, h.ExternalModuleCall)): From e26f14b97ff1df00dcc1e219269468089f1dfdb5 Mon Sep 17 00:00:00 2001 From: Zeyi Wang Date: Tue, 24 Jan 2023 18:32:06 +0000 Subject: [PATCH 04/11] feat: update syntax and add tests --- hdl21/flatten.py | 45 ++++++++------- hdl21/tests/test_flatten.py | 106 ++++++++++++++++++++++++++++++++++++ 2 files changed, 131 insertions(+), 20 deletions(-) create mode 100644 hdl21/tests/test_flatten.py diff --git a/hdl21/flatten.py b/hdl21/flatten.py index 5bc1db71..0535514e 100644 --- a/hdl21/flatten.py +++ b/hdl21/flatten.py @@ -1,26 +1,30 @@ from dataclasses import dataclass, field, replace -from typing import Generator, Union +from typing import Generator, Union, Tuple, Dict import hdl21 as h @dataclass -class PrimNode: +class _PrimitiveNode: inst: h.Instance - path: tuple[h.Module, ...] = tuple() - conns: dict[str, h.Signal] = field(default_factory=dict) # type: ignore + path: Tuple[h.Module, ...] = tuple() + conns: Dict[str, h.Signal] = field(default_factory=dict) # type: ignore def __str__(self): name = self.make_name() path = [p.name for p in self.path] - conns = flat_conns(self.conns) + conns = _flat_conns(self.conns) return str(dict(name=name, path=path, conns=conns)) def make_name(self): return ":".join([p.name or "_" for p in self.path]) -def walk(m: h.Module, parents=tuple(), conns=dict()) -> Generator[PrimNode, None, None]: +def _walk( + m: h.Module, + parents=tuple(), + conns=None, +) -> Generator[_PrimitiveNode, None, None]: if not conns: conns = {**m.signals, **m.ports} for inst in m.instances.values(): @@ -28,13 +32,13 @@ def walk(m: h.Module, parents=tuple(), conns=dict()) -> Generator[PrimNode, None new_conns = {} new_parents = parents + (inst,) for src_port_name, sig in inst.conns.items(): - match sig: - case h.Signal(): - key = sig.name - case h.PortRef(): - key = sig.portname - case _: - raise ValueError(f"unexpected signal type: {type(sig)}") + if isinstance(sig, h.Signal): + key = sig.name + elif isinstance(sig, h.PortRef): + key = sig.portname + else: + raise ValueError(f"unexpected signal type: {type(sig)}") + new_sig_name = ":".join([p.name for p in parents] + [key]) if key in conns: target_sig = conns[key] @@ -47,9 +51,9 @@ def walk(m: h.Module, parents=tuple(), conns=dict()) -> Generator[PrimNode, None new_conns[src_port_name] = target_sig if isinstance(inst.of, h.PrimitiveCall): - yield PrimNode(inst, new_parents, new_conns) + yield _PrimitiveNode(inst, new_parents, new_conns) else: - yield from walk(inst.of, new_parents, new_conns) + yield from _walk(inst.of, new_parents, new_conns) def _find_signal(m: h.Module, name: str) -> h.Signal: @@ -82,18 +86,19 @@ def _walk_conns(conns, parents=tuple()): yield parents + (key,), val -def flat_conns(conns): +def _flat_conns(conns): return {":".join(reversed(path)): value.name for path, value in _walk_conns(conns)} -def flat_module(m: h.Module): +def flatten(m: h.Module) -> h.Module: m = h.elaborate(m) + if is_flat(m): + return m # recursively walk the module and collect all primitive instances - nodes = list(walk(m)) - # for n in nodes: - # logger.debug(str(n)) + nodes = list(_walk(m)) + # NOTE: should we rename the module here? new_module = h.Module((m.name or "module") + "_flat") for port_name in m.ports: new_module.add(h.Port(name=port_name)) diff --git a/hdl21/tests/test_flatten.py b/hdl21/tests/test_flatten.py new file mode 100644 index 00000000..16ea197e --- /dev/null +++ b/hdl21/tests/test_flatten.py @@ -0,0 +1,106 @@ +from dataclasses import replace +from enum import Enum +from typing import Any + +import pytest + +import hdl21 as h +from hdl21.flatten import flatten, is_flat + + +@h.module +class Inverter: + vdd, vss, vin, vout = h.Ports(4) + + pmos = h.Pmos()(d=vout, g=vin, s=vdd, b=vdd) + nmos = h.Nmos()(d=vout, g=vin, s=vss, b=vss) + + +@h.module +class Buffer: + vdd, vss, vin, vout = h.Ports(4) + inv_1 = Inverter(vdd=vdd, vss=vss, vin=vin) # type: ignore + inv_2 = Inverter(vdd=vdd, vss=vss, vin=inv_1.vout, vout=vout) # type: ignore + + +@h.module +class DoubleBuffer: + vdd, vss, vin, vout = h.Ports(4) + buffer_1 = Buffer(vdd=vdd, vss=vss, vin=vin) + buffer_2 = Buffer(vdd=vdd, vss=vss, vin=buffer_1.vout, vout=vout) + + +@h.module +class InvBuffer: + vdd, vss, vin, vout = h.Ports(4) + inv = Inverter(vdd=vdd, vss=vss, vin=vin) + buffer = Buffer(vdd=vdd, vss=vss, vin=inv.vout, vout=vout) + + +def test_is_flat(): + assert is_flat(Inverter) + assert not is_flat(Buffer) + assert not is_flat(InvBuffer) + + +def test_flatten_inv(): + inv = Inverter + assert is_flat(inv) + inv_raw_proto = h.to_proto(h.elaborate(inv)) + inv_flatten_proto = h.to_proto(flatten(inv)) + assert inv_raw_proto == inv_flatten_proto + + +def test_flatten_buffer(): + buffer = Buffer + assert not is_flat(buffer) + buffer_flat = flatten(buffer) + assert buffer_flat.instances.keys() == { + "inv_1:pmos", + "inv_1:nmos", + "inv_2:pmos", + "inv_2:nmos", + } + assert buffer_flat.ports.keys() == {"vdd", "vss", "vin", "vout"} + assert buffer_flat.signals.keys() == {"inv_1_vout"} + assert is_flat(buffer_flat) + + +def test_flatten_double_buffer(): + double_buffer = DoubleBuffer + assert not is_flat(double_buffer) + double_buffer_flat = flatten(double_buffer) + assert double_buffer_flat.instances.keys() == { + "buffer_1:inv_1:pmos", + "buffer_1:inv_1:nmos", + "buffer_1:inv_2:pmos", + "buffer_1:inv_2:nmos", + "buffer_2:inv_1:pmos", + "buffer_2:inv_1:nmos", + "buffer_2:inv_2:pmos", + "buffer_2:inv_2:nmos", + } + assert double_buffer_flat.ports.keys() == {"vdd", "vss", "vin", "vout"} + assert double_buffer_flat.signals.keys() == { + "buffer_1_vout", # connection between two buffers + "buffer_1:inv_1_vout", # internal signal between two inverters in buffer 1 + "buffer_2:inv_1_vout", # internal signal between two inverters in buffer 2 + } + assert is_flat(double_buffer_flat) + + +def test_flatten_inv_buffer(): + inv_buffer = InvBuffer + assert not is_flat(inv_buffer) + inv_buffer_flat = flatten(inv_buffer) + assert inv_buffer_flat.instances.keys() == { + "inv:pmos", + "inv:nmos", + "buffer:inv_1:pmos", + "buffer:inv_1:nmos", + "buffer:inv_2:pmos", + "buffer:inv_2:nmos", + } + assert inv_buffer_flat.ports.keys() == {"vdd", "vss", "vin", "vout"} + assert inv_buffer_flat.signals.keys() == {"inv_vout", "buffer:inv_1_vout"} + assert is_flat(inv_buffer_flat) From c0d28e4e85733c10fe16210b67944cb732dcdc43 Mon Sep 17 00:00:00 2001 From: Zeyi Wang Date: Tue, 24 Jan 2023 22:36:34 +0000 Subject: [PATCH 05/11] test: improve coverage by testing node repr --- hdl21/flatten.py | 30 +++++++++++++++--------------- hdl21/tests/test_flatten.py | 10 +++++++++- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/hdl21/flatten.py b/hdl21/flatten.py index 0535514e..3757082a 100644 --- a/hdl21/flatten.py +++ b/hdl21/flatten.py @@ -4,6 +4,18 @@ import hdl21 as h +def _walk_conns(conns, parents=tuple()): + for key, val in conns.items(): + if isinstance(val, dict): + yield from _walk_conns(val, parents + (key,)) + else: + yield parents + (key,), val + + +def _flat_conns(conns): + return {":".join(reversed(path)): value.name for path, value in _walk_conns(conns)} + + @dataclass class _PrimitiveNode: inst: h.Instance @@ -20,7 +32,7 @@ def make_name(self): return ":".join([p.name or "_" for p in self.path]) -def _walk( +def walk( m: h.Module, parents=tuple(), conns=None, @@ -53,7 +65,7 @@ def _walk( if isinstance(inst.of, h.PrimitiveCall): yield _PrimitiveNode(inst, new_parents, new_conns) else: - yield from _walk(inst.of, new_parents, new_conns) + yield from walk(inst.of, new_parents, new_conns) def _find_signal(m: h.Module, name: str) -> h.Signal: @@ -78,25 +90,13 @@ def is_flat(m: Union[h.Instance, h.Instantiable]) -> bool: raise ValueError(f"Unexpected type {type(m)}") -def _walk_conns(conns, parents=tuple()): - for key, val in conns.items(): - if isinstance(val, dict): - yield from _walk_conns(val, parents + (key,)) - else: - yield parents + (key,), val - - -def _flat_conns(conns): - return {":".join(reversed(path)): value.name for path, value in _walk_conns(conns)} - - def flatten(m: h.Module) -> h.Module: m = h.elaborate(m) if is_flat(m): return m # recursively walk the module and collect all primitive instances - nodes = list(_walk(m)) + nodes = list(walk(m)) # NOTE: should we rename the module here? new_module = h.Module((m.name or "module") + "_flat") diff --git a/hdl21/tests/test_flatten.py b/hdl21/tests/test_flatten.py index 16ea197e..4457a0d0 100644 --- a/hdl21/tests/test_flatten.py +++ b/hdl21/tests/test_flatten.py @@ -5,7 +5,7 @@ import pytest import hdl21 as h -from hdl21.flatten import flatten, is_flat +from hdl21.flatten import flatten, is_flat, walk @h.module @@ -104,3 +104,11 @@ def test_flatten_inv_buffer(): assert inv_buffer_flat.ports.keys() == {"vdd", "vss", "vin", "vout"} assert inv_buffer_flat.signals.keys() == {"inv_vout", "buffer:inv_1_vout"} assert is_flat(inv_buffer_flat) + + +def test_flatten_node_desc(): + nodes = list(walk(Buffer)) + assert ( + str(nodes[0]) + == "{'name': 'inv_1:pmos', 'path': ['inv_1', 'pmos'], 'conns': {'d': 'inv_1_vout', 'g': 'vin', 's': 'vdd', 'b': 'vdd'}}" + ) From e55c2bf744a182cf47d47f1b4e541c884456eb89 Mon Sep 17 00:00:00 2001 From: Zeyi Wang Date: Thu, 25 May 2023 01:11:45 +0000 Subject: [PATCH 06/11] feat: updates flatten.py --- hdl21/flatten.py | 39 ++++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/hdl21/flatten.py b/hdl21/flatten.py index 3757082a..12989828 100644 --- a/hdl21/flatten.py +++ b/hdl21/flatten.py @@ -1,5 +1,9 @@ +""" +""" + +import copy from dataclasses import dataclass, field, replace -from typing import Generator, Union, Tuple, Dict +from typing import Dict, Generator, Tuple, Union import hdl21 as h @@ -40,14 +44,11 @@ def walk( if not conns: conns = {**m.signals, **m.ports} for inst in m.instances.values(): - # logger.debug(f"walk: {m.name} / {inst.name}") new_conns = {} new_parents = parents + (inst,) for src_port_name, sig in inst.conns.items(): if isinstance(sig, h.Signal): key = sig.name - elif isinstance(sig, h.PortRef): - key = sig.portname else: raise ValueError(f"unexpected signal type: {type(sig)}") @@ -68,14 +69,15 @@ def walk( yield from walk(inst.of, new_parents, new_conns) -def _find_signal(m: h.Module, name: str) -> h.Signal: - for port_name in m.ports: - if port_name == name: - return m.ports[port_name] - for sig_name in m.signals: - if sig_name == name: - return m.signals[sig_name] - raise ValueError(f"Signal {name} not found in module {m.name}") +def _find_signal_or_port(m: h.Module, name: str) -> h.Signal: + """Find a signal or port by name""" + + if (port := m.ports.get(name, None)) is not None: + return port + elif (sig := m.signals.get(name, None)) is not None: + return sig + else: + raise ValueError(f"Signal {name} not found in module {m.name}") def is_flat(m: Union[h.Instance, h.Instantiable]) -> bool: @@ -85,7 +87,10 @@ def is_flat(m: Union[h.Instance, h.Instantiable]) -> bool: return True elif isinstance(m, h.Module): insts = m.instances.values() - return all(isinstance(inst.of, h.PrimitiveCall) for inst in insts) + return all( + isinstance(inst.of, (h.PrimitiveCall, h.ExternalModuleCall)) + for inst in insts + ) else: raise ValueError(f"Unexpected type {type(m)}") @@ -100,22 +105,22 @@ def flatten(m: h.Module) -> h.Module: # NOTE: should we rename the module here? new_module = h.Module((m.name or "module") + "_flat") - for port_name in m.ports: - new_module.add(h.Port(name=port_name)) + for port in m.ports.values(): + new_module.add(copy.copy(port)) # add all signals to the root level for n in nodes: for sig in n.conns.values(): sig_name = sig.name if sig_name not in new_module.ports: - new_module.add(h.Signal(name=sig_name)) + new_module.add(copy.copy(sig)) # add all connections to the root level with names resolved for n in nodes: new_inst = new_module.add(n.inst.of(), name=n.make_name()) for src_port_name, sig in n.conns.items(): - matching_sig = _find_signal(new_module, sig.name) + matching_sig = _find_signal_or_port(new_module, sig.name) new_inst.connect(src_port_name, matching_sig) return new_module From 07d304e64bf0e751932d89d0b05402f224543644 Mon Sep 17 00:00:00 2001 From: Zeyi Wang Date: Thu, 25 May 2023 01:34:06 +0000 Subject: [PATCH 07/11] feat: update flatten.py for docs and format test_flatten.py --- hdl21/flatten.py | 36 ++++++++++++++++++++++++++++++++++++ hdl21/tests/test_flatten.py | 1 - 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/hdl21/flatten.py b/hdl21/flatten.py index 12989828..231d3129 100644 --- a/hdl21/flatten.py +++ b/hdl21/flatten.py @@ -96,6 +96,42 @@ def is_flat(m: Union[h.Instance, h.Instantiable]) -> bool: def flatten(m: h.Module) -> h.Module: + r"""Flatten a module by moving all nested instances, ports and signals to the root level. + + For example, if we have a buffer module with two inverters, `inv_1` and `inv_2`, each with + two transistors, `pmos` and `nmos`, the module hierarchy looks like this: + All signals and ports will be moved to the root level too. + + ``` + buffer + ├── inv_1 + │ ├── pmos + │ └── nmos + └── inv_2 + ├── pmos + └── nmos + ``` + + This function will flatten the module to the following structure: + + ``` + buffer_flat + ├── inv_1:pmos + ├── inv_1:nmos + ├── inv_2:pmos + └── inv_2:nmos + ``` + + Nested signals and ports will be renamed and moved to the root level as well. For example, say a + module with two buffers, `buffer_1` and `buffer_2`, each with two inverters. On the top level, + the original ports `vdd`, `vss`, `vin` and `vout` are preserved. Nested signals and ports such + the connection between two buffers, the internal signal between two inverters in buffer 1 will + be renamed and moved to the root level as `buffer_1_vout`, `buffer_1:inv_1_vout`, + `buffer_2:inv_1_vout`. + + See tests/test_flatten.py for more examples. + """ + m = h.elaborate(m) if is_flat(m): return m diff --git a/hdl21/tests/test_flatten.py b/hdl21/tests/test_flatten.py index 4457a0d0..585ea2bf 100644 --- a/hdl21/tests/test_flatten.py +++ b/hdl21/tests/test_flatten.py @@ -11,7 +11,6 @@ @h.module class Inverter: vdd, vss, vin, vout = h.Ports(4) - pmos = h.Pmos()(d=vout, g=vin, s=vdd, b=vdd) nmos = h.Nmos()(d=vout, g=vin, s=vss, b=vss) From b0df06f7522625bd35c8aecd2a208c8a439d1392 Mon Sep 17 00:00:00 2001 From: Zeyi Wang Date: Thu, 25 May 2023 01:37:07 +0000 Subject: [PATCH 08/11] feat: update docs for flatten.py --- hdl21/flatten.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/hdl21/flatten.py b/hdl21/flatten.py index 231d3129..f61478b8 100644 --- a/hdl21/flatten.py +++ b/hdl21/flatten.py @@ -1,4 +1,7 @@ """ +Flatten a module by moving all nested instances, ports and signals to the root level. + +See function `flatten()` for details. """ import copy @@ -97,7 +100,7 @@ def is_flat(m: Union[h.Instance, h.Instantiable]) -> bool: def flatten(m: h.Module) -> h.Module: r"""Flatten a module by moving all nested instances, ports and signals to the root level. - + For example, if we have a buffer module with two inverters, `inv_1` and `inv_2`, each with two transistors, `pmos` and `nmos`, the module hierarchy looks like this: All signals and ports will be moved to the root level too. @@ -111,9 +114,9 @@ def flatten(m: h.Module) -> h.Module: ├── pmos └── nmos ``` - + This function will flatten the module to the following structure: - + ``` buffer_flat ├── inv_1:pmos @@ -121,17 +124,17 @@ def flatten(m: h.Module) -> h.Module: ├── inv_2:pmos └── inv_2:nmos ``` - + Nested signals and ports will be renamed and moved to the root level as well. For example, say a module with two buffers, `buffer_1` and `buffer_2`, each with two inverters. On the top level, the original ports `vdd`, `vss`, `vin` and `vout` are preserved. Nested signals and ports such the connection between two buffers, the internal signal between two inverters in buffer 1 will be renamed and moved to the root level as `buffer_1_vout`, `buffer_1:inv_1_vout`, `buffer_2:inv_1_vout`. - + See tests/test_flatten.py for more examples. """ - + m = h.elaborate(m) if is_flat(m): return m From 6592955ec2f87a5e73e8de5d3404037756bcf0dd Mon Sep 17 00:00:00 2001 From: Zeyi Wang Date: Thu, 25 May 2023 23:29:30 +0000 Subject: [PATCH 09/11] fix: python 3.7 syntax --- hdl21/flatten.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/hdl21/flatten.py b/hdl21/flatten.py index f61478b8..56098107 100644 --- a/hdl21/flatten.py +++ b/hdl21/flatten.py @@ -75,12 +75,15 @@ def walk( def _find_signal_or_port(m: h.Module, name: str) -> h.Signal: """Find a signal or port by name""" - if (port := m.ports.get(name, None)) is not None: + port = m.ports.get(name, None) + if port is not None: return port - elif (sig := m.signals.get(name, None)) is not None: + + sig = m.signals.get(name, None) + if sig is not None: return sig - else: - raise ValueError(f"Signal {name} not found in module {m.name}") + + raise ValueError(f"Signal {name} not found in module {m.name}") def is_flat(m: Union[h.Instance, h.Instantiable]) -> bool: From 4cb4f0070dc4ac4e85e599be45f89f000f0a5723 Mon Sep 17 00:00:00 2001 From: Dan Fritchman Date: Wed, 31 May 2023 23:21:29 +0000 Subject: [PATCH 10/11] Add flattening tests for Generators, Slices, Concats --- hdl21/tests/test_flatten.py | 69 +++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/hdl21/tests/test_flatten.py b/hdl21/tests/test_flatten.py index 585ea2bf..d07f9f84 100644 --- a/hdl21/tests/test_flatten.py +++ b/hdl21/tests/test_flatten.py @@ -111,3 +111,72 @@ def test_flatten_node_desc(): str(nodes[0]) == "{'name': 'inv_1:pmos', 'path': ['inv_1', 'pmos'], 'conns': {'d': 'inv_1_vout', 'g': 'vin', 's': 'vdd', 'b': 'vdd'}}" ) + + +def test_flatten_with_slices(): + """Flatten a Module with slices""" + + @h.module + class InvTwoPack: + # A two-bit-bus's worth of inverters + vdd, vss = h.Ports(2) + vin, vout = 2 * h.Port(width=2) + inv_0 = Inverter(vdd=vdd, vss=vss, vin=vin[0], vout=vout[0]) # type: ignore + inv_1 = Inverter(vdd=vdd, vss=vss, vin=vin[1], vout=vout[1]) # type: ignore + + flattened = flatten(InvTwoPack) + + +def test_flatten_with_concat(): + """Test flattening a module with signal concatenations.""" + + @h.module + class NmosArray: + # An instance array of Nmos'es. Maybe for a current DAC? + g, vss = 2 * h.Port() + d = h.Port(width=8) + nmoses = 8 * h.Nmos()(g=g, d=d, s=vss, b=vss) + + @h.module + class M: + # A thing that instantiates an NmosArray, and concatenates some signals together to form its drain connections. + g, vss = 2 * h.Port() + s4 = h.Signal(width=4) + s2 = h.Signal(width=2) + s1 = h.Signal(width=1) + s0 = h.Signal(width=1) + nmos_array = NmosArray(d=h.Concat(s4, s2, s1, s0), g=g, vss=vss) + + flattened = flatten(M) + + +def test_flatten_generator(): + """Flatten a Generator + (And it has Concat too! Kinda. But not by the point flattening happens).""" + + @h.paramclass + class Params: + how_many = h.Param(dtype=int, default=5, desc="How many inverters you want?") + + @h.generator + def Gen(params: Params) -> h.Module: + if params.how_many < 2: + raise ValueError + + @h.module + class M: + vdd, vss, vin, vout = h.Ports(4) + internal_signal = h.Signal(width=params.how_many - 1) + # Array of `how_many` series-connected inverters + invs = params.how_many * Inverter( + vdd=vdd, + vss=vss, + vin=h.Concat(vin, internal_signal), + vout=h.Concat(internal_signal, vout), + ) + + return M + + gen5 = Gen(how_many=5) + assert not is_flat(gen5) + flattened = flatten(gen5) From 094d6c55de7d9d2333ed3f50e58877c200f4a097 Mon Sep 17 00:00:00 2001 From: Zeyi Wang Date: Thu, 23 Nov 2023 20:21:32 -0500 Subject: [PATCH 11/11] remove GeneratorCall test --- hdl21/flatten.py | 2 +- hdl21/tests/test_flatten.py | 31 ------------------------------- 2 files changed, 1 insertion(+), 32 deletions(-) diff --git a/hdl21/flatten.py b/hdl21/flatten.py index 56098107..ea787d5e 100644 --- a/hdl21/flatten.py +++ b/hdl21/flatten.py @@ -89,7 +89,7 @@ def _find_signal_or_port(m: h.Module, name: str) -> h.Signal: def is_flat(m: Union[h.Instance, h.Instantiable]) -> bool: if isinstance(m, h.Instance): return is_flat(m.of) - elif isinstance(m, (h.PrimitiveCall, h.GeneratorCall, h.ExternalModuleCall)): + elif isinstance(m, (h.PrimitiveCall, h.ExternalModuleCall)): return True elif isinstance(m, h.Module): insts = m.instances.values() diff --git a/hdl21/tests/test_flatten.py b/hdl21/tests/test_flatten.py index d07f9f84..913354f2 100644 --- a/hdl21/tests/test_flatten.py +++ b/hdl21/tests/test_flatten.py @@ -149,34 +149,3 @@ class M: flattened = flatten(M) - -def test_flatten_generator(): - """Flatten a Generator - (And it has Concat too! Kinda. But not by the point flattening happens).""" - - @h.paramclass - class Params: - how_many = h.Param(dtype=int, default=5, desc="How many inverters you want?") - - @h.generator - def Gen(params: Params) -> h.Module: - if params.how_many < 2: - raise ValueError - - @h.module - class M: - vdd, vss, vin, vout = h.Ports(4) - internal_signal = h.Signal(width=params.how_many - 1) - # Array of `how_many` series-connected inverters - invs = params.how_many * Inverter( - vdd=vdd, - vss=vss, - vin=h.Concat(vin, internal_signal), - vout=h.Concat(internal_signal, vout), - ) - - return M - - gen5 = Gen(how_many=5) - assert not is_flat(gen5) - flattened = flatten(gen5)