Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flattening Tweaks #213

Merged
merged 14 commits into from
Jan 5, 2024
1 change: 1 addition & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ jobs:
- name: Upload coverage to Codecov # Adapted from https://github.com/codecov/codecov-action#usage
uses: codecov/codecov-action@v3
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.xml
env_vars: OS,PYTHON
name: codecov-umbrella
Expand Down
198 changes: 198 additions & 0 deletions hdl21/flatten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
"""
Flatten a module by moving all nested instances, ports and signals to the root level.

See function `flatten()` for details.
"""

import copy
from pydantic.dataclasses import dataclass
from dataclasses import field, replace
from typing import Dict, Generator, List, Optional

import hdl21 as h
from .signal import _copy_to_internal
from .datatype import AllowArbConfig


def _walk_conns(conns, parents=tuple()):
for key, val in conns.items():
if isinstance(val, dict):
yield from _walk_conns(val, parents + (key,))

Check warning on line 20 in hdl21/flatten.py

View check run for this annotation

Codecov / codecov/patch

hdl21/flatten.py#L20

Added line #L20 was not covered by tests
else:
yield parents + (key,), val


def _flat_conns(conns):
return {":".join(reversed(path)): value.name for path, value in _walk_conns(conns)}


@dataclass(config=AllowArbConfig)
class FlattenedInstance:
"""
# Flattened Instance
An instance that survives flattening, because its contents are either (a) Primitive or (b) External.
"""

inst: h.Instance
path: List[h.Instance] = field(default_factory=list)
conns: Dict[str, h.Signal] = field(default_factory=dict)

def __post_init_post_parse__(self):
# Assert that this instance's target is either a primitive, or external
if not isinstance(self.inst.of, (h.PrimitiveCall, h.ExternalModuleCall)):
raise ValueError(f"Invalid flattened instance {self}")

Check warning on line 43 in hdl21/flatten.py

View check run for this annotation

Codecov / codecov/patch

hdl21/flatten.py#L43

Added line #L43 was not covered by tests

def __str__(self):
name = self.make_name()
path = [p.name for p in self.path]
conns = _flat_conns(self.conns)
return str(dict(name=name, path=path, conns=conns))

def make_name(self):
return ":".join([p.name or "_" for p in self.path])


def walk(
m: h.Module,
parents: List[h.Instance],
conns: Optional[Dict[str, h.Signal]] = None,
) -> Generator[FlattenedInstance, None, None]:
if conns is None:
conns = {**m.signals, **m.ports}
for inst in m.instances.values():
new_conns = {}
new_parents = parents + [inst]
for src_port_name, sig in inst.conns.items():
if isinstance(sig, h.Signal):
key = sig.name
elif isinstance(sig, (h.Slice, h.Concat)):
msg = f"Flattening `Slice` and `Concat` is not (yet) supported"
raise NotImplementedError(msg)
elif isinstance(sig, (h.PortRef, h.BundleInstance, h.AnonymousBundle)):

Check warning on line 71 in hdl21/flatten.py

View check run for this annotation

Codecov / codecov/patch

hdl21/flatten.py#L71

Added line #L71 was not covered by tests
# This shouldn't happen in normal use, but could in principle if
# someone e.g. calls this `walk` function directly.
msg = f"Error: {sig} should not have reached this stage in flattening"
raise RuntimeError(msg)

Check warning on line 75 in hdl21/flatten.py

View check run for this annotation

Codecov / codecov/patch

hdl21/flatten.py#L74-L75

Added lines #L74 - L75 were not covered by tests
else:
raise TypeError(f"Invalid connection {sig}")

Check warning on line 77 in hdl21/flatten.py

View check run for this annotation

Codecov / codecov/patch

hdl21/flatten.py#L77

Added line #L77 was not covered by tests

new_sig_name = ":".join([p.name for p in parents] + [key])
if key in conns:
target_sig = conns[key]
elif key in m.signals:
target_sig = replace(
_copy_to_internal(m.signals[key]), name=new_sig_name
)
elif key in m.ports:
target_sig = replace(_copy_to_internal(m.ports[key]), name=new_sig_name)

Check warning on line 87 in hdl21/flatten.py

View check run for this annotation

Codecov / codecov/patch

hdl21/flatten.py#L86-L87

Added lines #L86 - L87 were not covered by tests
else:
raise ValueError(f"signal {key} not found")

Check warning on line 89 in hdl21/flatten.py

View check run for this annotation

Codecov / codecov/patch

hdl21/flatten.py#L89

Added line #L89 was not covered by tests
new_conns[src_port_name] = target_sig

if isinstance(inst.of, h.PrimitiveCall):
yield FlattenedInstance(inst, new_parents, new_conns)
else:
yield from walk(inst.of, new_parents, new_conns)


def _find_signal_or_port(m: h.Module, name: str) -> h.Signal:
"""Find a signal or port by name"""

port = m.ports.get(name, None)
if port is not None:
return port

sig = m.signals.get(name, None)
if sig is not None:
return sig

raise ValueError(f"Signal {name} not found in module {m.name}")

Check warning on line 109 in hdl21/flatten.py

View check run for this annotation

Codecov / codecov/patch

hdl21/flatten.py#L109

Added line #L109 was not covered by tests


def is_flat(m: h.Instantiable) -> bool:
"""Boolean indication of whether `Instantiable` `m` is already flat."""

if isinstance(m, (h.PrimitiveCall, h.ExternalModuleCall)):
return True

Check warning on line 116 in hdl21/flatten.py

View check run for this annotation

Codecov / codecov/patch

hdl21/flatten.py#L116

Added line #L116 was not covered by tests
if isinstance(m, h.Module):
instancelike = (
list(m.instances.values())
+ list(m.instarrays.values())
+ list(m.instbundles.values())
)
return all(
isinstance(inst.of, (h.PrimitiveCall, h.ExternalModuleCall))
for inst in instancelike
)
raise TypeError(f"Invalid `Instantiable` argument to `is_flat`: {m}")

Check warning on line 127 in hdl21/flatten.py

View check run for this annotation

Codecov / codecov/patch

hdl21/flatten.py#L127

Added line #L127 was not covered by tests


def flatten(m: h.Instantiable) -> h.Instantiable:
r"""Flatten a module by moving all nested instances, ports and signals to the root level.

For example, if we have a buffer module with two inverters, `inv_1` and `inv_2`, each with
two transistors, `pmos` and `nmos`, the module hierarchy looks like this:
All signals and ports will be moved to the root level too.

```
buffer
├── inv_1
│ ├── pmos
│ └── nmos
└── inv_2
├── pmos
└── nmos
```

This function will flatten the module to the following structure:

```
buffer_flat
├── inv_1:pmos
├── inv_1:nmos
├── inv_2:pmos
└── inv_2:nmos
```

Nested signals and ports will be renamed and moved to the root level as well. For example, say a
module with two buffers, `buffer_1` and `buffer_2`, each with two inverters. On the top level,
the original ports `vdd`, `vss`, `vin` and `vout` are preserved. Nested signals and ports such
the connection between two buffers, the internal signal between two inverters in buffer 1 will
be renamed and moved to the root level as `buffer_1_vout`, `buffer_1:inv_1_vout`,
`buffer_2:inv_1_vout`.

See tests/test_flatten.py for more examples.
"""

m = h.elaborate(m)
if is_flat(m):
return m

# recursively walk the module and collect all primitive instances
nodes: List[FlattenedInstance] = list(walk(m, parents=[]))

# Create our new, flattened Module
# Note that by virtue of going through elaboration above, `m.name` should be set.
# Check for it nonetheless, and raise an error if not.
if m.name is None:
raise ValueError(f"Anonymous Module {m} cannot be flattened. (Give it a name.)")

Check warning on line 178 in hdl21/flatten.py

View check run for this annotation

Codecov / codecov/patch

hdl21/flatten.py#L178

Added line #L178 was not covered by tests
new_module = h.Module(m.name + "_flat")
for port in m.ports.values():
new_module.add(copy.copy(port))

# add all signals to the root level
for n in nodes:
for sig in n.conns.values():
sig_name = sig.name
if sig_name not in new_module.ports:
new_module.add(copy.copy(sig))

# add all connections to the root level with names resolved
for n in nodes:
new_inst = new_module.add(n.inst.of(), name=n.make_name())

for src_port_name, sig in n.conns.items():
matching_sig = _find_signal_or_port(new_module, sig.name)
new_inst.connect(src_port_name, matching_sig)

return new_module
149 changes: 149 additions & 0 deletions hdl21/tests/test_flatten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import pytest

import hdl21 as h
from hdl21.flatten import flatten, is_flat, walk


@h.module
class Inverter:
vdd, vss, vin, vout = h.Ports(4)
pmos = h.Pmos()(d=vout, g=vin, s=vdd, b=vdd)
nmos = h.Nmos()(d=vout, g=vin, s=vss, b=vss)


@h.module
class Buffer:
vdd, vss, vin, vout = h.Ports(4)
inv_1 = Inverter(vdd=vdd, vss=vss, vin=vin) # type: ignore
inv_2 = Inverter(vdd=vdd, vss=vss, vin=inv_1.vout, vout=vout) # type: ignore


@h.module
class DoubleBuffer:
vdd, vss, vin, vout = h.Ports(4)
buffer_1 = Buffer(vdd=vdd, vss=vss, vin=vin)
buffer_2 = Buffer(vdd=vdd, vss=vss, vin=buffer_1.vout, vout=vout)


@h.module
class InvBuffer:
vdd, vss, vin, vout = h.Ports(4)
inv = Inverter(vdd=vdd, vss=vss, vin=vin)
buffer = Buffer(vdd=vdd, vss=vss, vin=inv.vout, vout=vout)


def test_is_flat():
assert is_flat(Inverter)
assert not is_flat(Buffer)
assert not is_flat(InvBuffer)


def test_flatten_inv():
inv = Inverter
assert is_flat(inv)
inv_raw_proto = h.to_proto(h.elaborate(inv))
inv_flatten_proto = h.to_proto(flatten(inv))
assert inv_raw_proto == inv_flatten_proto


def test_flatten_buffer():
buffer = Buffer
assert not is_flat(buffer)
buffer_flat = flatten(buffer)
assert buffer_flat.instances.keys() == {
"inv_1:pmos",
"inv_1:nmos",
"inv_2:pmos",
"inv_2:nmos",
}
assert buffer_flat.ports.keys() == {"vdd", "vss", "vin", "vout"}
assert buffer_flat.signals.keys() == {"inv_1_vout"}
assert is_flat(buffer_flat)


def test_flatten_double_buffer():
double_buffer = DoubleBuffer
assert not is_flat(double_buffer)
double_buffer_flat = flatten(double_buffer)
assert double_buffer_flat.instances.keys() == {
"buffer_1:inv_1:pmos",
"buffer_1:inv_1:nmos",
"buffer_1:inv_2:pmos",
"buffer_1:inv_2:nmos",
"buffer_2:inv_1:pmos",
"buffer_2:inv_1:nmos",
"buffer_2:inv_2:pmos",
"buffer_2:inv_2:nmos",
}
assert double_buffer_flat.ports.keys() == {"vdd", "vss", "vin", "vout"}
assert double_buffer_flat.signals.keys() == {
"buffer_1_vout", # connection between two buffers
"buffer_1:inv_1_vout", # internal signal between two inverters in buffer 1
"buffer_2:inv_1_vout", # internal signal between two inverters in buffer 2
}
assert is_flat(double_buffer_flat)


def test_flatten_inv_buffer():
inv_buffer = InvBuffer
assert not is_flat(inv_buffer)
inv_buffer_flat = flatten(inv_buffer)
assert inv_buffer_flat.instances.keys() == {
"inv:pmos",
"inv:nmos",
"buffer:inv_1:pmos",
"buffer:inv_1:nmos",
"buffer:inv_2:pmos",
"buffer:inv_2:nmos",
}
assert inv_buffer_flat.ports.keys() == {"vdd", "vss", "vin", "vout"}
assert inv_buffer_flat.signals.keys() == {"inv_vout", "buffer:inv_1_vout"}
assert is_flat(inv_buffer_flat)


def test_flatten_node_desc():
m = h.elaborate(Buffer)
nodes = list(walk(m, parents=[]))
assert (
str(nodes[0])
== "{'name': 'inv_1:pmos', 'path': ['inv_1', 'pmos'], 'conns': {'d': 'inv_1_vout', 'g': 'vin', 's': 'vdd', 'b': 'vdd'}}"
)


@pytest.mark.xfail(reason="FIXME: flatten with slices & concats")
def test_flatten_with_slices():
"""Flatten a Module with slices"""

@h.module
class InvTwoPack:
# A two-bit-bus's worth of inverters
vdd, vss = h.Ports(2)
vin, vout = 2 * h.Port(width=2)
inv_0 = Inverter(vdd=vdd, vss=vss, vin=vin[0], vout=vout[0]) # type: ignore
inv_1 = Inverter(vdd=vdd, vss=vss, vin=vin[1], vout=vout[1]) # type: ignore

flattened = flatten(InvTwoPack)


@pytest.mark.xfail(reason="FIXME: flatten with slices & concats")
def test_flatten_with_concat():
"""Test flattening a module with signal concatenations."""

@h.module
class NmosArray:
# An instance array of Nmos'es. Maybe for a current DAC?
g, vss = 2 * h.Port()
d = h.Port(width=8)
nmoses = 8 * h.Nmos()(g=g, d=d, s=vss, b=vss)

@h.module
class M:
# A thing that instantiates an NmosArray, and concatenates some signals together to form its drain connections.
g, vss = 2 * h.Port()
s4 = h.Signal(width=4)
s2 = h.Signal(width=2)
s1 = h.Signal(width=1)
s0 = h.Signal(width=1)
nmos_array = NmosArray(d=h.Concat(s4, s2, s1, s0), g=g, vss=vss)

flattened = flatten(M)
Loading