Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: hierarchical flattening #90

Closed
wants to merge 11 commits into from
168 changes: 168 additions & 0 deletions hdl21/flatten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
"""
Flatten a module by moving all nested instances, ports and signals to the root level.

See function `flatten()` for details.
"""

import copy
from dataclasses import dataclass, field, replace
uduse marked this conversation as resolved.
Show resolved Hide resolved
from typing import Dict, Generator, Tuple, Union

import hdl21 as h


def _walk_conns(conns, parents=tuple()):
for key, val in conns.items():
if isinstance(val, dict):
yield from _walk_conns(val, parents + (key,))
else:
yield parents + (key,), val


def _flat_conns(conns):
return {":".join(reversed(path)): value.name for path, value in _walk_conns(conns)}


@dataclass
class _PrimitiveNode:
inst: h.Instance
path: Tuple[h.Module, ...] = tuple()
conns: Dict[str, h.Signal] = field(default_factory=dict) # type: ignore

def __str__(self):
name = self.make_name()
path = [p.name for p in self.path]
conns = _flat_conns(self.conns)
return str(dict(name=name, path=path, conns=conns))

def make_name(self):
return ":".join([p.name or "_" for p in self.path])


def walk(
m: h.Module,
parents=tuple(),
conns=None,
) -> Generator[_PrimitiveNode, None, None]:
if not conns:
conns = {**m.signals, **m.ports}
for inst in m.instances.values():
new_conns = {}
new_parents = parents + (inst,)
for src_port_name, sig in inst.conns.items():
if isinstance(sig, h.Signal):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Things that will be in inst.conns.values, in addition to Signal, will include: Slice, Concat.

There is a Connectable union of everything that can go into conns before elaboration. After, that should be reduced to:

Connectable = Union[
    # Still gonna be there: 
    "Signal",
    "Slice",
    "Concat",
    # Removed during elaboration: 
    "NoConn",
    "PortRef",
    # And all this `Bundle` stuff is definitely removed via elaboration: 
    "BundleInstance",
    "AnonymousBundle",
    "BundleRef",
]

key = sig.name
else:
raise ValueError(f"unexpected signal type: {type(sig)}")

new_sig_name = ":".join([p.name for p in parents] + [key])
if key in conns:
target_sig = conns[key]
elif key in m.signals:
target_sig = replace(m.signals[key], name=new_sig_name)
elif key in m.ports:
target_sig = replace(m.ports[key], name=new_sig_name)
else:
raise ValueError(f"signal {key} not found")
new_conns[src_port_name] = target_sig

if isinstance(inst.of, h.PrimitiveCall):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs cases for:

  • GeneratorCall - which will need to be flattened, and
  • ExternalModuleCall - which we should regard as "already flat"

yield _PrimitiveNode(inst, new_parents, new_conns)
else:
yield from walk(inst.of, new_parents, new_conns)


def _find_signal_or_port(m: h.Module, name: str) -> h.Signal:
"""Find a signal or port by name"""

port = m.ports.get(name, None)
if port is not None:
return port

sig = m.signals.get(name, None)
if sig is not None:
return sig

raise ValueError(f"Signal {name} not found in module {m.name}")


def is_flat(m: Union[h.Instance, h.Instantiable]) -> bool:
if isinstance(m, h.Instance):
return is_flat(m.of)
uduse marked this conversation as resolved.
Show resolved Hide resolved
elif isinstance(m, (h.PrimitiveCall, h.ExternalModuleCall)):
return True
elif isinstance(m, h.Module):
insts = m.instances.values()
return all(
isinstance(inst.of, (h.PrimitiveCall, h.ExternalModuleCall))
for inst in insts
)
else:
raise ValueError(f"Unexpected type {type(m)}")


def flatten(m: h.Module) -> h.Module:
r"""Flatten a module by moving all nested instances, ports and signals to the root level.

For example, if we have a buffer module with two inverters, `inv_1` and `inv_2`, each with
two transistors, `pmos` and `nmos`, the module hierarchy looks like this:
All signals and ports will be moved to the root level too.

```
buffer
├── inv_1
│ ├── pmos
│ └── nmos
└── inv_2
├── pmos
└── nmos
```

This function will flatten the module to the following structure:

```
buffer_flat
├── inv_1:pmos
├── inv_1:nmos
├── inv_2:pmos
└── inv_2:nmos
```

Nested signals and ports will be renamed and moved to the root level as well. For example, say a
module with two buffers, `buffer_1` and `buffer_2`, each with two inverters. On the top level,
the original ports `vdd`, `vss`, `vin` and `vout` are preserved. Nested signals and ports such
the connection between two buffers, the internal signal between two inverters in buffer 1 will
be renamed and moved to the root level as `buffer_1_vout`, `buffer_1:inv_1_vout`,
`buffer_2:inv_1_vout`.

See tests/test_flatten.py for more examples.
"""

m = h.elaborate(m)
if is_flat(m):
return m

# recursively walk the module and collect all primitive instances
nodes = list(walk(m))

# NOTE: should we rename the module here?
new_module = h.Module((m.name or "module") + "_flat")
for port in m.ports.values():
new_module.add(copy.copy(port))

# add all signals to the root level
for n in nodes:
for sig in n.conns.values():
sig_name = sig.name
if sig_name not in new_module.ports:
new_module.add(copy.copy(sig))

# add all connections to the root level with names resolved
for n in nodes:
new_inst = new_module.add(n.inst.of(), name=n.make_name())

for src_port_name, sig in n.conns.items():
matching_sig = _find_signal_or_port(new_module, sig.name)
new_inst.connect(src_port_name, matching_sig)

return new_module
151 changes: 151 additions & 0 deletions hdl21/tests/test_flatten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from dataclasses import replace
from enum import Enum
from typing import Any

import pytest

import hdl21 as h
from hdl21.flatten import flatten, is_flat, walk


@h.module
class Inverter:
vdd, vss, vin, vout = h.Ports(4)
pmos = h.Pmos()(d=vout, g=vin, s=vdd, b=vdd)
nmos = h.Nmos()(d=vout, g=vin, s=vss, b=vss)


@h.module
class Buffer:
vdd, vss, vin, vout = h.Ports(4)
inv_1 = Inverter(vdd=vdd, vss=vss, vin=vin) # type: ignore
inv_2 = Inverter(vdd=vdd, vss=vss, vin=inv_1.vout, vout=vout) # type: ignore


@h.module
class DoubleBuffer:
vdd, vss, vin, vout = h.Ports(4)
buffer_1 = Buffer(vdd=vdd, vss=vss, vin=vin)
buffer_2 = Buffer(vdd=vdd, vss=vss, vin=buffer_1.vout, vout=vout)


@h.module
class InvBuffer:
vdd, vss, vin, vout = h.Ports(4)
inv = Inverter(vdd=vdd, vss=vss, vin=vin)
buffer = Buffer(vdd=vdd, vss=vss, vin=inv.vout, vout=vout)


def test_is_flat():
assert is_flat(Inverter)
assert not is_flat(Buffer)
assert not is_flat(InvBuffer)


def test_flatten_inv():
inv = Inverter
assert is_flat(inv)
inv_raw_proto = h.to_proto(h.elaborate(inv))
inv_flatten_proto = h.to_proto(flatten(inv))
assert inv_raw_proto == inv_flatten_proto


def test_flatten_buffer():
buffer = Buffer
assert not is_flat(buffer)
buffer_flat = flatten(buffer)
assert buffer_flat.instances.keys() == {
"inv_1:pmos",
"inv_1:nmos",
"inv_2:pmos",
"inv_2:nmos",
}
assert buffer_flat.ports.keys() == {"vdd", "vss", "vin", "vout"}
assert buffer_flat.signals.keys() == {"inv_1_vout"}
assert is_flat(buffer_flat)


def test_flatten_double_buffer():
double_buffer = DoubleBuffer
assert not is_flat(double_buffer)
double_buffer_flat = flatten(double_buffer)
assert double_buffer_flat.instances.keys() == {
"buffer_1:inv_1:pmos",
"buffer_1:inv_1:nmos",
"buffer_1:inv_2:pmos",
"buffer_1:inv_2:nmos",
"buffer_2:inv_1:pmos",
"buffer_2:inv_1:nmos",
"buffer_2:inv_2:pmos",
"buffer_2:inv_2:nmos",
}
assert double_buffer_flat.ports.keys() == {"vdd", "vss", "vin", "vout"}
assert double_buffer_flat.signals.keys() == {
"buffer_1_vout", # connection between two buffers
"buffer_1:inv_1_vout", # internal signal between two inverters in buffer 1
"buffer_2:inv_1_vout", # internal signal between two inverters in buffer 2
}
assert is_flat(double_buffer_flat)


def test_flatten_inv_buffer():
inv_buffer = InvBuffer
assert not is_flat(inv_buffer)
inv_buffer_flat = flatten(inv_buffer)
assert inv_buffer_flat.instances.keys() == {
"inv:pmos",
"inv:nmos",
"buffer:inv_1:pmos",
"buffer:inv_1:nmos",
"buffer:inv_2:pmos",
"buffer:inv_2:nmos",
}
assert inv_buffer_flat.ports.keys() == {"vdd", "vss", "vin", "vout"}
assert inv_buffer_flat.signals.keys() == {"inv_vout", "buffer:inv_1_vout"}
assert is_flat(inv_buffer_flat)


def test_flatten_node_desc():
nodes = list(walk(Buffer))
assert (
str(nodes[0])
== "{'name': 'inv_1:pmos', 'path': ['inv_1', 'pmos'], 'conns': {'d': 'inv_1_vout', 'g': 'vin', 's': 'vdd', 'b': 'vdd'}}"
)


def test_flatten_with_slices():
"""Flatten a Module with slices"""

@h.module
class InvTwoPack:
# A two-bit-bus's worth of inverters
vdd, vss = h.Ports(2)
vin, vout = 2 * h.Port(width=2)
inv_0 = Inverter(vdd=vdd, vss=vss, vin=vin[0], vout=vout[0]) # type: ignore
inv_1 = Inverter(vdd=vdd, vss=vss, vin=vin[1], vout=vout[1]) # type: ignore

flattened = flatten(InvTwoPack)


def test_flatten_with_concat():
"""Test flattening a module with signal concatenations."""

@h.module
class NmosArray:
# An instance array of Nmos'es. Maybe for a current DAC?
g, vss = 2 * h.Port()
d = h.Port(width=8)
nmoses = 8 * h.Nmos()(g=g, d=d, s=vss, b=vss)

@h.module
class M:
# A thing that instantiates an NmosArray, and concatenates some signals together to form its drain connections.
g, vss = 2 * h.Port()
s4 = h.Signal(width=4)
s2 = h.Signal(width=2)
s1 = h.Signal(width=1)
s0 = h.Signal(width=1)
nmos_array = NmosArray(d=h.Concat(s4, s2, s1, s0), g=g, vss=vss)

flattened = flatten(M)

Loading