diff --git a/docs/explanation/system.md b/docs/explanation/system.md index ceddba3..283480d 100644 --- a/docs/explanation/system.md +++ b/docs/explanation/system.md @@ -58,3 +58,41 @@ Component.model_json_schema() - `infrasys` includes some basic quantities in [infrasys.quantities](#quantity-api). - Pint will automatically convert a list or list of lists of values into a `numpy.ndarray`. infrasys will handle serialization/de-serialization of these types. + + +### Component Associations +The system tracks associations between components in order to optimize lookups. + +For example, suppose a Generator class has a field for a Bus. It is trivial to find a generator's +bus. However, if you need to find all generators connected to specific bus, you would have to +traverse all generators in the system and check their bus values. + +Every time you add a component to a system, `infrasys` inspects the component type for composed +components. It checks for directly connected components, such as `Generator.bus`, and lists of +components. (It does not inspect other composite data structures like dictionaries.) + +`infrasys` stores these component associations in a SQLite table and so lookups are fast. + +Here is how to complete this example: + +```python +generators = system.list_parent_components(bus) +``` + +If you only want to find specific types, you can pass that type as well. +```python +generators = system.list_parent_components(bus, component_type=Generator) +``` + +**Warning**: There is one potentially problematic case. + +Suppose that you have a system with generators and buses and then reassign the buses, as in +``` +gen1.bus = other_bus +``` + +`infrasys` cannot detect such reassignments and so the component associations will be incorrect. +You must inform `infrasys` to rebuild its internal table. +``` +system.rebuild_component_associations() +``` diff --git a/src/infrasys/component_associations.py b/src/infrasys/component_associations.py new file mode 100644 index 0000000..c635ce6 --- /dev/null +++ b/src/infrasys/component_associations.py @@ -0,0 +1,99 @@ +import sqlite3 +from typing import Optional, Type +from uuid import UUID + +from loguru import logger + +from infrasys.component import Component +from infrasys.utils.sqlite import execute + + +class ComponentAssociations: + """Stores associations between components. Allows callers to quickly find components composed + by other components, such as the generator to which a bus is connected.""" + + TABLE_NAME = "component_associations" + + def __init__(self) -> None: + # This uses a different database because it is not persisted when the system + # is saved to files. It will be rebuilt during de-serialization. + self._con = sqlite3.connect(":memory:") + self._create_metadata_table() + + def _create_metadata_table(self): + schema = [ + "id INTEGER PRIMARY KEY", + "component_uuid TEXT", + "component_type TEXT", + "attached_component_uuid TEXT", + "attached_component_type TEXT", + ] + schema_text = ",".join(schema) + cur = self._con.cursor() + execute(cur, f"CREATE TABLE {self.TABLE_NAME}({schema_text})") + execute(cur, f"CREATE INDEX by_ac_uuid ON {self.TABLE_NAME}(attached_component_uuid)") + self._con.commit() + logger.debug("Created in-memory component associations table") + + def add(self, *components: Component): + """Store an association between each component and directly attached subcomponents. + + - Inspects the type of each field of each component's type. Looks for subtypes of + Component and lists of subtypes of Component. + - Does not consider component fields that are dictionaries or other data structures. + """ + rows = [] + for component in components: + for field in type(component).model_fields: + val = getattr(component, field) + if isinstance(val, Component): + rows.append(self._make_row(component, val)) + elif isinstance(val, list) and val and isinstance(val[0], Component): + for item in val: + rows.append(self._make_row(component, item)) + # FUTURE: consider supporting dictionaries like these examples: + # dict[str, Component] + # dict[str, [Component]] + + if rows: + self._insert_rows(rows) + + def clear(self) -> None: + """Clear all component associations.""" + execute(self._con.cursor(), f"DELETE FROM {self.TABLE_NAME}") + logger.info("Cleared all component associations.") + + def list_parent_components( + self, component: Component, component_type: Optional[Type[Component]] = None + ) -> list[UUID]: + """Return a list of all component UUIDS that compose this component. + For example, return all components connected to a bus. + """ + where_clause = "WHERE attached_component_uuid = ?" + if component_type is None: + params = [str(component.uuid)] + else: + params = [str(component.uuid), component_type.__name__] + where_clause += " AND component_type = ?" + query = f"SELECT component_uuid FROM {self.TABLE_NAME} {where_clause}" + cur = self._con.cursor() + return [UUID(x[0]) for x in execute(cur, query, params)] + + def _insert_rows(self, rows: list[tuple]) -> None: + cur = self._con.cursor() + placeholder = ",".join(["?"] * len(rows[0])) + query = f"INSERT INTO {self.TABLE_NAME} VALUES({placeholder})" + try: + cur.executemany(query, rows) + finally: + self._con.commit() + + @staticmethod + def _make_row(component: Component, attached_component: Component): + return ( + None, + str(component.uuid), + type(component).__name__, + str(attached_component.uuid), + type(attached_component).__name__, + ) diff --git a/src/infrasys/component_manager.py b/src/infrasys/component_manager.py index 88e7c2a..8e5edb1 100644 --- a/src/infrasys/component_manager.py +++ b/src/infrasys/component_manager.py @@ -1,13 +1,19 @@ """Manages components""" -from collections import defaultdict import itertools -from typing import Any, Callable, Iterable, Type +from collections import defaultdict +from typing import Any, Callable, Iterable, Optional, Type from uuid import UUID from loguru import logger from infrasys.component import Component -from infrasys.exceptions import ISAlreadyAttached, ISNotStored, ISOperationNotAllowed +from infrasys.component_associations import ComponentAssociations +from infrasys.exceptions import ( + ISAlreadyAttached, + ISNotStored, + ISOperationNotAllowed, + ISInvalidParameter, +) from infrasys.models import make_label, get_class_and_name_from_label @@ -23,6 +29,7 @@ def __init__( self._components_by_uuid: dict[UUID, Component] = {} self._uuid = uuid self._auto_add_composed_components = auto_add_composed_components + self._associations = ComponentAssociations() @property def auto_add_composed_components(self) -> bool: @@ -34,7 +41,7 @@ def auto_add_composed_components(self, val: bool) -> None: """Set auto_add_composed_components.""" self._auto_add_composed_components = val - def add(self, *args: Component, deserialization_in_progress=False) -> None: + def add(self, *components: Component, deserialization_in_progress=False) -> None: """Add one or more components to the system. Raises @@ -42,9 +49,15 @@ def add(self, *args: Component, deserialization_in_progress=False) -> None: ISAlreadyAttached Raised if a component is already attached to a system. """ - for component in args: + if not components: + msg = "add_associations requires at least one component" + raise ISInvalidParameter(msg) + + for component in components: self._add(component, deserialization_in_progress) + self._associations.add(*components) + def get(self, component_type: Type[Component], name: str) -> Any: """Return the component with the passed type and name. @@ -167,8 +180,22 @@ def iter_all(self) -> Iterable[Any]: """Return an iterator over all components.""" return self._components_by_uuid.values() + def list_parent_components( + self, component: Component, component_type: Optional[Type[Component]] = None + ) -> list[Component]: + """Return a list of all components that compose this component.""" + return [ + self.get_by_uuid(x) + for x in self._associations.list_parent_components( + component, component_type=component_type + ) + ] + def to_records( - self, component_type: Type[Component], filter_func: Callable | None = None, **kwargs + self, + component_type: Type[Component], + filter_func: Callable | None = None, + **kwargs, ) -> Iterable[dict]: """Return a dictionary representation of the requested components. @@ -207,6 +234,15 @@ def remove(self, component: Component) -> Any: msg = f"{component.label} is not stored" raise ISNotStored(msg) + attached_components = self.list_parent_components(component) + if attached_components: + label = ", ".join((x.label for x in attached_components)) + msg = ( + f"Cannot remove {component.label} because it is attached to these components: " + f"{label}" + ) + raise ISOperationNotAllowed(msg) + container = self._components[component_type][component.name] for i, comp in enumerate(container): if comp.uuid == component.uuid: @@ -259,6 +295,14 @@ def change_uuid(self, component: Component) -> None: msg = "change_component_uuid" raise NotImplementedError(msg) + def rebuild_component_associations(self) -> None: + """Clear the component associations and rebuild the table. This may be necessary + if a user reassigns connected components that are part of a system. + """ + self._associations.clear() + self._associations.add(*self.iter_all()) + logger.info("Rebuilt all component associations.") + def update( self, component_type: Type[Component], @@ -292,6 +336,7 @@ def _add(self, component: Component, deserialization_in_progress: bool) -> None: self._components[cls][name].append(component) self._components_by_uuid[component.uuid] = component + logger.debug("Added {} to the system", component.label) def _check_component_addition(self, component: Component) -> None: @@ -303,7 +348,7 @@ def _check_component_addition(self, component: Component) -> None: self._handle_composed_component(val) # Recurse. self._check_component_addition(val) - if isinstance(val, list) and val and isinstance(val[0], Component): + elif isinstance(val, list) and val and isinstance(val[0], Component): for item in val: self._handle_composed_component(item) # Recurse. diff --git a/src/infrasys/exceptions.py b/src/infrasys/exceptions.py index 8254ac2..0e0c120 100644 --- a/src/infrasys/exceptions.py +++ b/src/infrasys/exceptions.py @@ -18,13 +18,17 @@ class ISFileExists(ISBaseException): class ISConflictingArguments(ISBaseException): - """Raised if the arguments are conflict.""" + """Raised if the arguments conflict.""" class ISConflictingSystem(ISBaseException): """Raised if the system has conflicting values.""" +class ISInvalidParameter(ISBaseException): + """Raised if a parameter is invalid.""" + + class ISNotStored(ISBaseException): """Raised if the requested object is not stored.""" diff --git a/src/infrasys/system.py b/src/infrasys/system.py index f40437e..b595a46 100644 --- a/src/infrasys/system.py +++ b/src/infrasys/system.py @@ -597,6 +597,22 @@ def get_component_types(self) -> Iterable[Type[Component]]: """ return self._component_mgr.get_types() + def list_parent_components( + self, component: Component, component_type: Optional[Type[Component]] = None + ) -> list[Component]: + """Return a list of all components that compose this component. + + An example usage is where you need to find all components connected to a bus and the Bus + class does not contain that information. The system tracks these connections internally + and can find those components quickly. + + Examples + -------- + >>> components = system.list_parent_components(bus) + >>> print(f"These components are connected to {bus.label}: ", " ".join(components)) + """ + return self._component_mgr.list_parent_components(component, component_type=component_type) + def list_components_by_name(self, component_type: Type[Component], name: str) -> list[Any]: """Return all components that match component_type and name. @@ -625,6 +641,12 @@ def iter_all_components(self) -> Iterable[Any]: """ return self._component_mgr.iter_all() + def rebuild_component_associations(self) -> None: + """Clear the component associations and rebuild the table. This may be necessary + if a user reassigns connected components that are part of a system. + """ + self._component_mgr.rebuild_component_associations() + def remove_component(self, component: Component) -> Any: """Remove the component from the system and return it. @@ -636,6 +658,8 @@ def remove_component(self, component: Component) -> Any: ------ ISNotStored Raised if the component is not stored in the system. + ISOperationNotAllowed + Raised if the other components hold references to this component. Examples -------- diff --git a/tests/test_system.py b/tests/test_system.py index 8dbb96a..4a7cbfc 100644 --- a/tests/test_system.py +++ b/tests/test_system.py @@ -8,6 +8,7 @@ from infrasys.exceptions import ( ISAlreadyAttached, + ISInvalidParameter, ISNotStored, ISOperationNotAllowed, ISConflictingArguments, @@ -31,6 +32,8 @@ def test_system(): gen = SimpleGenerator(name="test-gen", active_power=1.0, rating=1.0, bus=bus, available=True) subsystem = SimpleSubsystem(name="test-subsystem", generators=[gen]) system.add_components(geo, bus, gen, subsystem) + with pytest.raises(ISInvalidParameter): + system.add_components() gen2 = system.get_component(SimpleGenerator, "test-gen") assert gen2 is gen @@ -141,6 +144,53 @@ def test_get_components_multiple_types(): assert len(selected_components) == 2 # 1 SimpleGenerator + 1 RenewableGenerator +def test_component_associations(tmp_path): + system = SimpleSystem() + for i in range(3): + geo = Location(x=i, y=i + 1) + bus = SimpleBus(name=f"bus{i}", voltage=1.1, coordinates=geo) + gen1 = SimpleGenerator( + name=f"gen{i}a", active_power=1.0, rating=1.0, bus=bus, available=True + ) + gen2 = SimpleGenerator( + name=f"gen{i}b", active_power=1.0, rating=1.0, bus=bus, available=True + ) + subsystem = SimpleSubsystem(name=f"test-subsystem{i}", generators=[gen1, gen2]) + system.add_components(geo, bus, gen1, gen2, subsystem) + + def check_attached_components(my_sys): + for i in range(3): + bus = my_sys.get_component(SimpleBus, f"bus{i}") + gen1 = my_sys.get_component(SimpleGenerator, f"gen{i}a") + gen2 = my_sys.get_component(SimpleGenerator, f"gen{i}b") + attached = my_sys.list_parent_components(bus, component_type=SimpleGenerator) + assert len(attached) == 2 + labels = {gen1.label, gen2.label} + for component in attached: + assert component.label in labels + attached_subsystems = my_sys.list_parent_components(component) + assert len(attached_subsystems) == 1 + assert attached_subsystems[0].name == f"test-subsystem{i}" + assert not my_sys.list_parent_components(attached_subsystems[0]) + + for component in (bus, gen1, gen2): + with pytest.raises(ISOperationNotAllowed): + my_sys.remove_component(component) + + check_attached_components(system) + system._component_mgr._associations.clear() + for component in system.iter_all_components(): + assert not system.list_parent_components(component) + + system.rebuild_component_associations() + check_attached_components(system) + + save_dir = tmp_path / "test_system" + system.save(save_dir) + system2 = SimpleSystem.from_json(save_dir / "system.json") + check_attached_components(system2) + + def test_time_series_attach_from_array(): system = SimpleSystem() bus = SimpleBus(name="test-bus", voltage=1.1)